{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Logistic regression with PyTorch" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.optim import Optimizer\n", "from torch.utils.data import DataLoader\n", "\n", "# torchvision: popular datasets, model architectures, and common image transformations for computer vision.\n", "from torchvision import datasets\n", "from torchvision.transforms import transforms\n", "\n", "from random import randint\n", "from random import shuffle\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "\n", "'''\n", "Step 1: Prepare dataset\n", "'''\n", "# Use data with only 4 and 9 as labels: which is hardest to classify\n", "label_1, label_2 = 4, 9\n", "\n", "# MNIST training data\n", "train_set = datasets.MNIST(root='./mnist_data/', train=True, transform=transforms.ToTensor(), download=True)\n", "\n", "# Use data with two labels\n", "idx = (train_set.targets == label_1) + (train_set.targets == label_2)\n", "train_set.data = train_set.data[idx]\n", "train_set.targets = train_set.targets[idx]\n", "train_set.targets[train_set.targets == label_1] = -1\n", "train_set.targets[train_set.targets == label_2] = 1\n", "\n", "# MNIST testing data\n", "test_set = datasets.MNIST(root='./mnist_data/', train=False, transform=transforms.ToTensor())\n", "\n", "# Use data with two labels\n", "idx = (test_set.targets == label_1) + (test_set.targets == label_2)\n", "test_set.data = test_set.data[idx]\n", "test_set.targets = test_set.targets[idx]\n", "test_set.targets[test_set.targets == label_1] = -1\n", "test_set.targets[test_set.targets == label_2] = 1" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [ "'''\n", "Step 2: Define the neural network class\n", "'''\n", "class LR(nn.Module) :\n", " '''\n", " Initialize model\n", " input_dim : dimension of given input data\n", " 
'''\n", " # MNIST data is 28x28 images\n", " def __init__(self, input_dim=28*28) :\n", " super().__init__()\n", " self.linear = nn.Linear(input_dim, 1, bias=True)\n", "\n", " ''' forward given input x '''\n", " def forward(self, x) :\n", " return self.linear(x.float().view(-1, 28*28))" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "'''\n", "Step 3: Create the model, specify loss function and optimizer.\n", "'''\n", "model = LR() # Define a Neural Network Model\n", "\n", "def logistic_loss(output, target):\n", " return -torch.nn.functional.logsigmoid(target*output)\n", "\n", "loss_function = logistic_loss # Specify loss function\n", "optimizer = torch.optim.SGD(model.parameters(), lr=1e-4) # specify SGD with learning rate" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Time ellapsed in training is: 0.2060258388519287\n" ] } ], "source": [ "'''\n", "Step 4: Train model with SGD\n", "'''\n", "import time\n", "start = time.time()\n", "for _ in range(1000) :\n", " # Sample a random data for training\n", " ind = randint(0, len(train_set.data)-1)\n", " image, label = train_set.data[ind], train_set.targets[ind]\n", "\n", " # Clear previously computed gradient\n", " optimizer.zero_grad()\n", "\n", " # then compute gradient with forward and backward passes\n", " train_loss = loss_function(model(image), label.float())\n", " train_loss.backward()\n", "\n", " #(This syntax will make more sense once we learn about minibatches)\n", "\n", " # perform SGD step (parameter update)\n", " optimizer.step()\n", " \n", "end = time.time()\n", "print(f\"Time ellapsed in training is: {end-start}\")" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[Test set] Average loss: 5.8580, Accuracy: 1885/1991 (94.68%)\n", "\n" ] } ], "source": [ "'''\n", "Step 5: Test model 
(Evaluate the accuracy)\n", "'''\n", "test_loss, correct = 0, 0\n", "misclassified_ind = []\n", "correct_ind = []\n", "\n", "# Evaluate accuracy using test data. No gradients are needed here, so torch.no_grad()\n", "# avoids building an autograd graph for every test image.\n", "with torch.no_grad():\n", "    for ind in range(len(test_set.data)):\n", "        image, label = test_set.data[ind], test_set.targets[ind]\n", "\n", "        # evaluate model\n", "        output = model(image)\n", "\n", "        # Calculate cumulative loss\n", "        test_loss += loss_function(output, label.float()).item()\n", "\n", "        # Make a prediction: correct when the sign of the logit agrees with the +/-1 label\n", "        # (a logit of exactly 0 counts as correct)\n", "        if output.item() * label.item() >= 0:\n", "            correct += 1\n", "            correct_ind.append(ind)\n", "        else:\n", "            misclassified_ind.append(ind)\n", "\n", "# Print out the results\n", "n_test = len(test_set.data)\n", "print('[Test set] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\\n'.format(\n", "    test_loss / n_test, correct, n_test,\n", "    100. * correct / n_test))" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "'''\n", "Step 6: Show some incorrectly classified images and some correctly classified ones\n", "''' \n", "\n", "# Misclassified images\n", "shuffle(misclassified_ind)\n", "fig = plt.figure(1, figsize=(15, 6))\n", "fig.suptitle('Misclassified Figures', fontsize=16)\n", "\n", "for k in range(3) :\n", " image = test_set.data[misclassified_ind[k]].cpu().numpy().astype('uint8')\n", " ax = fig.add_subplot(1, 3, k+1)\n", " true_label = test_set.targets[misclassified_ind[k]]\n", "\n", " if true_label == -1 :\n", " ax.set_title('True Label: {}\\nPrediction: {}'.format(label_1, label_2))\n", " else :\n", " ax.set_title('True Label: {}\\nPrediction: {}'.format(label_2, label_1))\n", " plt.imshow(image, cmap='gray')\n", "plt.show()\n", "\n", "# Correctly classified images\n", "shuffle(correct_ind)\n", "fig = plt.figure(2, figsize=(15, 6))\n", "fig.suptitle('Correctly-classified Figures', fontsize=16)\n", "\n", "for k in range(3) :\n", " image = test_set.data[correct_ind[k]].cpu().numpy().astype('uint8')\n", " ax = fig.add_subplot(1, 3, k+1)\n", " true_label = test_set.targets[correct_ind[k]]\n", "\n", " if true_label == -1 :\n", " ax.set_title('True Label: {}\\nPrediction: {}'.format(label_1, label_1))\n", " else :\n", " ax.set_title('True Label: {}\\nPrediction: {}'.format(label_2, label_2))\n", " plt.imshow(image, cmap='gray')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Inspect model parameters" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Linear(in_features=784, out_features=1, bias=True)\n", "torch.Size([1, 784])\n", "Parameter containing:\n", "tensor([-0.0199], requires_grad=True)\n" ] } ], "source": [ "# model parameters visible as an iterator (generator is an iterator)\n", "# print(model.parameters())\n", "\n", "# for parameter in 
model.parameters():\n", "# print(parameter.shape)\n", "\n", "# # model parameter directly obtainable from layer\n", "print(model.linear)\n", "print(model.linear.weight.shape)\n", "print(model.linear.bias) #available if bias=True\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAVa0lEQVR4nO2dW2xd9ZXGv2WbBHJ14iROYkJCShAkIAgyaKSMRhlVUwESCn3oqDxUjIQmfShSK/VhEPNQHtFo2qoPo0rpgJqOOqBKLSoPaKYRQooqJMCEkAuBxBNCLnbu9wuJk6x58EFywfv73HPsc077/36SZfss//f+n733533O+f5rrchMGGP++ulo9QSMMc3BYjemECx2YwrBYjemECx2Ywqhq5k7mz59es6YMaMy3tHB//d0dnZWxq5fv07Hqm2r+M2bNytjEVH3WIA/L0A/t66u6tM4MjJCx95yyy00rlBuzo0bN+oey54XoI8rO6dq32rbUzl3NZad00uXLuHq1avjXpANiT0iHgXwMwCdAP4zM19kfz9jxgysX7++Mj5r1iy6PxY/efJk3WMBYObMmTR++fLlypgS69WrV2l8zpw5NH7ixAkanz9/fmXs2LFjdGxfXx+NK9Q/k1OnTlXG2D8CAFi4cCGNX7hwgcbZjUXtm51vQD9vdk4A4PPPP697LDunW7ZsqYzV/TI+IjoB/AeAxwCsBvBURKyud3vGmKmlkffsjwAYzMz9mXkNwKsANkzOtIwxk00jYu8DcGjM74drj/0JEbExIgYiYkC9nDXGTB2NiH28DwG+8qlFZm7KzP7M7J8+fXoDuzPGNEIjYj8MYNmY328HMNTYdIwxU0UjYn8PwKqIuDMipgH4NoDXJ2daxpjJpm7rLTOvR8SzAP4Xo9bby5m5m42JCOohKr+axZV1puwxZbUw6059FqHevly7do3GlQXFfPhbb72Vjj1//jyNq3Oijhs7L43aWyrOjouyO9XzVlbuoUOHaHzJkiWVMfW8brvttsoYW1vQkM+emW8AeKORbRhjmoOXyxpTCBa7MYVgsRtTCBa7MYVgsRtTCBa7MYXQ1Hz2rq4u9Pb2VsaHhvgCvLlz51bGzp07R8cybxLQPjxLK2TzArSPPm3aNBpXudWXLl2qjCmfXXndKs+/u7ubxpmPr465SmFV6w9Yeq26XpTPro6rOqdsbmosO+bsfPnObkwhWOzGFILFbkwhWOzGFILFbkwhWOzGFEJTrbeRkREcPXq0Mq5K6LIKsvPmzZP7ZqhyzcwOuXjxIh2rqoWq6rHKomJljVWpaGVJqriyBZn1psoxr1ixgsbVOV20aFFlTFWXVcf8yJEjNK7OOZu7sjvVtVq53bpGGWP+4rDYjSkEi92YQrDYjSkEi92YQrDYjSkEi92YQmiqz65oxDc9ffo0Hcu6ZgLA7NmzaZylW6qywqrDrEqBVaWo2RoDdVzYMQW0n6zWCDCfX6Woqrkrr7yR7rZq/YDatyrRza4Zda2qtQ9V+M5uTCFY7MYUgsVuTCFY7MYUgsVuTCFY7MYUgsVuTCE03Wdn/iVrYwtwb/PKlSt0LCthDQBnz56lcYZqF61aOquyxarsMctvVmsXWH0BQK8B6Ovro3Fa2ljkbe/bt4/G1Tlj6w/U2g
jlZauc8kZ8ejWWPW+23YbEHhEHAFwAcAPA9czsb2R7xpipYzLu7H+fmXyJmDGm5fg9uzGF0KjYE8AfIuL9iNg43h9ExMaIGIiIAfXe1RgzdTT6Mn5dZg5FxCIAWyLi48zcOvYPMnMTgE0AMG/ePF5h0BgzZTR0Z8/Modr34wBeA/DIZEzKGDP51C32iJgZEbO/+BnANwDsmqyJGWMml0ZexvcCeK3mEXcB+O/M/B81iHmrjdSNX7ZsGR2rarur/GOWl63y1VlLZUDXKFftgdlzV3nXKq97w4YNNK688uHh4cqYqhuv1l2o9Qcs137v3r10rGqbrNZtqFbWzKdXNQYOHjxYGZsSnz0z9wN4oN7xxpjmYuvNmEKw2I0pBIvdmEKw2I0pBIvdmEJoaoprRDRkYbHSwKrssLLelIWk0lAZ999/P43PmDGDxpXN09PTUxlTLZtVCqs6bsqaO3ToUGVMLZ/ev38/jSvr7tVXX62MqWOurkVlCyobmR1XdS0uWLCgMsbOt+/sxhSCxW5MIVjsxhSCxW5MIVjsxhSCxW5MIVjsxhRCU332jo4OWqJXpYKydEzV5laliarSwKdOnaqMPfzww3SsSr9VpaiVn8xKJiufXbVcfuyxx2icefwAP24qzZSlcgLaK2etrlesWEHHqhbe6nmr63HOnDmVMdYeHODnm60H8Z3dmEKw2I0pBIvdmEKw2I0pBIvdmEKw2I0pBIvdmEJoqs9+8+ZN6iEuXLiw7m2rvOtGWugC3HdVbYuVx688WRVn+c8qJ1zlyn/00Uc0/u6779L4O++8UxlTz0u1ZGb1DQC+xuDee++lYz/88EMaVy2dlVfOrje1NkLlu1eOq2uUMeYvDovdmEKw2I0pBIvdmEKw2I0pBIvdmEKw2I0phKb67F1dXdRLV94k86uVz668y7lz59I4y38eGRmhY99++20aV9x11100vnXr1srY8ePH6djly5fTuKqfrtYnMJ//ypUrdOzQ0BCNP/HEEzTOahSo9QeLFy+mcVVvX9XEZz79mTNn6Fjms7P1JPLOHhEvR8TxiNg15rH5EbElIvbVvldn0xtj2oKJvIz/JYBHv/TYcwDezMxVAN6s/W6MaWOk2DNzK4Av91baAGBz7efNAJ6c3GkZYyabej+g683MYQCofV9U9YcRsTEiBiJiQL1HM8ZMHVP+aXxmbsrM/szsV8kDxpipo16xH4uIJQBQ+84/8jXGtJx6xf46gKdrPz8N4PeTMx1jzFQhffaIeAXAegALIuIwgB8BeBHAbyLiGQAHAXxrIjvLTOpPKi+c1cSeNm0aHau2rfpps/G7d++mY1X983Xr1tH4Bx98QOOdnZ2VMdYfHdC121XOuKppv23btsrYokWVH/UAANauXUvjyutmayPU50eqj4Bav6COG1szwurCAzzPn617kGLPzKcqQl9XY40x7YOXyxpTCBa7MYVgsRtTCBa7MYVgsRtTCE0vJc1S/1S6JENZHcpKWbBgAY0fPXq0MqbSQFVLZlWuWbVVZsdNlWtWbY/vuOMOGlflnpcuXVoZU8dcWWurVq2qe/zg4CAdq0qPq7Lnqv04s3LVMWXpt2y7vrMbUwgWuzGFYLEbUwgWuzGFYLEbUwgWuzGFYLEbUwhN99lZyeclS5bQ8cxvVuWce3p6aPzjjz+mcZZCq8oOqzRTlQKrymSfP3++MqbSSFkaKACcO3eOxlU65kMPPVQZU2nFau3Ep59+SuNs/YFKzWVpw2rbgG7Tzcqmq3TtelNcfWc3phAsdmMKwWI3phAsdmMKwWI3phAsdmMKwWI3phCa6rNHBPUQWc44APT29lbGlDepWvSqnHGWe818bkDnRrMS2YBuJ/3ZZ59VxlTJ5JUrV9K4Oq6sfTDAfd/u7m469vLlyzT+1ltv0TjbvqpvoI65qhOg1oyw56auBxZvqGWzMeavA4vdmEKw2I0pBIvdmEKw2I0pBIvdmEKw2I0phKb67AD3AVV+M8tZV/
nJp0+fpnG176GhocqY8qKVD688XVVHnHnlyuNnedUA8Mknn9D4mjVraJydFzW35cuX0/iBAwdonNWNV3n8qsW3mru6nubMmVMZU30Gzpw5UxljHry8s0fEyxFxPCJ2jXnshYg4EhHba1+Pq+0YY1rLRF7G/xLAo+M8/tPMfLD29cbkTssYM9lIsWfmVgD8NbAxpu1p5AO6ZyNiR+1lfmUhsojYGBEDETHA+rwZY6aWesX+cwBfA/AggGEAP676w8zclJn9mdk/ffr0OndnjGmUusSemccy80Zm3gTwCwCPTO60jDGTTV1ij4ix+XvfBLCr6m+NMe2B9Nkj4hUA6wEsiIjDAH4EYH1EPAggARwA8N2J7CwzaR7xrFmz6HhWP135mn19fTR+/PhxGmde+t69e+lY5enec889NK5qt999992VMebJAsCxY8doXHn8O3bsoHFWV1756CqfXZ1TtkZA1btXXrcar65H9vmVWhPCPHpW716KPTOfGufhl9Q4Y0x74eWyxhSCxW5MIVjsxhSCxW5MIVjsxhRCU1NcOzo6qL2m7A5mQc2YMYOOVSWVlUXFrBbV3le1XL506RKNsxLaAG/L3Ki1pkomK9uQndPBwUE6VqUGr127lsZZuqcq16ysM9WSWZVFZ6jz7ZbNxhiKxW5MIVjsxhSCxW5MIVjsxhSCxW5MIVjsxhRCU332zs5OWjZZpTQyb1S14FVVclavXk3jPT09lTFVSpp5n4D2dE+dOkXj7LgdPHiQjt2zZw+NL168mMZVSWW2hkD5yarEtkpxZWmkjYwFdAnu22+/ve7tqzUfrEQ2K93tO7sxhWCxG1MIFrsxhWCxG1MIFrsxhWCxG1MIFrsxhdBUn31kZIS2PlZtcpnPrrxslbetPN0TJ05UxpQXrVBzU6Wkh4eHK2PKo1frC9QagI4Ofr9g+fRq23feeWdD+2YlvNW+1bqN2bNn07iqI8DqLyxYsICOrTdX3nd2YwrBYjemECx2YwrBYjemECx2YwrBYjemECx2YwqhqT57RNAa6yovnNUgZ3m8gK7NztrgAjz/WK0PYD74RGB14QFe213VdX/ggQdoXLWyVu2q2XlRfQJUzrhq8c3WXqhzovLZlc+urgk2N+XRqxoCVcg7e0Qsi4i3ImJPROyOiO/XHp8fEVsiYl/tO29YbYxpKRN5GX8dwA8z814AfwPgexGxGsBzAN7MzFUA3qz9boxpU6TYM3M4M7fVfr4AYA+APgAbAGyu/dlmAE9O0RyNMZPAn/UBXUSsALAWwDsAejNzGBj9hwBg3DeWEbExIgYiYkC9DzLGTB0TFntEzALwWwA/yEzecW8MmbkpM/szs18VfTTGTB0TEntE3IJRof86M39Xe/hYRCypxZcA4B/bGmNairTeYjSv9CUAezLzJ2NCrwN4GsCLte+/n8C2qCXR3d2tNlGJstZUS2dVxprNTaWRqrLCyjZUKa5sbmvWrKFjVbqkKkWtXq2xks3qealzoixJ1lZZ2byNlvdWac8shVZZisy+ZvqaiM++DsB3AOyMiO21x57HqMh/ExHPADgI4FsT2JYxpkVIsWfmHwFUVY34+uROxxgzVXi5rDGFYLEbUwgWuzGFYLEbUwgWuzGF0NQU146ODprWePjwYTqeldhVvqny2U+ePEnjLKXxjjvuoGOVj65KIisvfOXKlZUxlUa6b98+Gl+6dCmN79y5k8aZ3zw4OEjH3nfffTSufHpWelxdD+p6GhkZoXGVGsy2r+bGUn9Z6qzv7MYUgsVuTCFY7MYUgsVuTCFY7MYUgsVuTCFY7MYUQlN99ps3b9ISvSrvm+WsK8+V5QAD2ttkraZZ3jQA9Pb20rgqmazKXLOcclWWWMXVGgGVt71w4cLKmMoZV8dFjT9z5kxlTF1rqlyzWn+g1m1cvHixMqae15UrVypjbN6+sxtTCBa7MYVgsRtTCBa7MYVgsRtTCBa7MYVgsRtTCE312YH6280CvDUxy3UHeJ
1uQLfoZdtvJP8Y0F62em6sVriq665yxpVfPG8eb97L8rZ7enroWJXnr+ofsOOqzjerdw/ouvFq7QXz0s+ePUvHsrrybD2J7+zGFILFbkwhWOzGFILFbkwhWOzGFILFbkwhWOzGFMJE+rMvA/ArAIsB3ASwKTN/FhEvAPhnACdqf/p8Zr7BtpWZ1O8+cuQInyzxJlm9bED3+lZ+sfJVGconV7nRbH0BwL1sNVZ53czDB7SPz2rasxoBAM/bBvTaiWvXrlXGVF14VatfXU/KK2e95dXcWH0D5rNPZFHNdQA/zMxtETEbwPsRsaUW+2lm/vsEtmGMaTET6c8+DGC49vOFiNgDgC8vMsa0HX/We/aIWAFgLYB3ag89GxE7IuLliBj3dXBEbIyIgYgYUEsUjTFTx4TFHhGzAPwWwA8y8zyAnwP4GoAHMXrn//F44zJzU2b2Z2a/en9njJk6JiT2iLgFo0L/dWb+DgAy81hm3sjMmwB+AeCRqZumMaZRpNhjtBXmSwD2ZOZPxjy+ZMyffRPArsmfnjFmspjIp/HrAHwHwM6I2F577HkAT0XEgwASwAEA35U76+qiNpRqg8veBiirZP78+TTOylQD3A5RZay7u7tp/MCBAzSu0oKZLahKaM+dO5fGVXquGs/ss+HhYTpWnRPVjprZhmrfyopV+1bXMjunzDIEeDtott+JfBr/RwDjNbqmnroxpr3wCjpjCsFiN6YQLHZjCsFiN6YQLHZjCsFiN6YQmt6ymaUGqvK7zM9WaaQq5VD5oqz1MGu/C+iyxay1MKD9ZoYqU63iKnV47969NM7WJ6jWxOfPn6dxNX50Pdj4qDbazMsG9HFRcXZc1JoQlhrMzqfv7MYUgsVuTCFY7MYUgsVuTCFY7MYUgsVuTCFY7MYUQiifdVJ3FnECwGdjHloAgPcEbh3tOrd2nRfgudXLZM5teWaOuyikqWL/ys4jBjKzv2UTILTr3Np1XoDnVi/NmptfxhtTCBa7MYXQarFvavH+Ge06t3adF+C51UtT5tbS9+zGmObR6ju7MaZJWOzGFEJLxB4Rj0bEJxExGBHPtWIOVUTEgYjYGRHbI2KgxXN5OSKOR8SuMY/Nj4gtEbGv9p0XOG/u3F6IiCO1Y7c9Ih5v0dyWRcRbEbEnInZHxPdrj7f02JF5NeW4Nf09e0R0AtgL4B8AHAbwHoCnMvOjpk6kgog4AKA/M1u+ACMi/g7ARQC/ysz7ao/9G4DTmfli7R/lvMz8lzaZ2wsALra6jXetW9GSsW3GATwJ4J/QwmNH5vWPaMJxa8Wd/REAg5m5PzOvAXgVwIYWzKPtycytAE5/6eENADbXft6M0Yul6VTMrS3IzOHM3Fb7+QKAL9qMt/TYkXk1hVaIvQ/AoTG/H0Z79XtPAH+IiPcjYmOrJzMOvZk5DIxePAAWtXg+X0a28W4mX2oz3jbHrp72543SCrGPVxisnfy/dZn5EIDHAHyv9nLVTIwJtfFuFuO0GW8L6m1/3iitEPthAMvG/H47gOoKek0mM4dq348DeA3t14r62BcddGvfeWXEJtJObbzHazOONjh2rWx/3gqxvwdgVUTcGRHTAHwbwOstmMdXiIiZtQ9OEBEzAXwD7deK+nUAT9d+fhrA71s4lz+hXdp4V7UZR4uPXcvbn2dm078API7RT+T/D8C/tmIOFfNaCeDD2tfuVs8NwCsYfVk3gtFXRM8A6AHwJoB9te/z22hu/wVgJ4AdGBXWkhbN7W8x+tZwB4Dtta/HW33syLyacty8XNaYQvAKOmMKwWI3phAsdmMKwWI3phAsdmMKwWI3phAsdmMK4f8BjxtHIqL640QAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# view the weight trained through logistic regression\n", "weight = model.linear.weight.detach().reshape(28,28).numpy()\n", "weight = weight - np.min(weight)\n", "weight = weight/np.max(weight)\n", "plt.imshow(weight, cmap='gray')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "DataLoader utility simplifies accessing data in training and testing" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Time ellapsed in training is: 0.3480958938598633\n", "[Test set] Average loss: 0.1989, Accuracy: 1850/1991 (92.92%)\n", "\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAA2cAAAFoCAYAAADTgoOZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAArC0lEQVR4nO3deZhcZZn38d8PIkHWhAlggBCYICgGBhkWB1B4XxgRBmQZIyIgOEhEI0Z0JMK8soyAbIIyIpLIPqjIIrsvCSEiMMgWEIKRLVeAQBJIiEDYSe7545yWoumup7rrVNfTne/nuvrqrnPues5d1ak7fdfz1DmOCAEAAAAA2mu5dicAAAAAAKA5AwAAAIAs0JwBAAAAQAZozgAAAAAgAzRnAAAAAJABmjMAAAAAyADNGQC0ge1DbEf5tXEX+3eq2b9LzfaLbM+uOJcNyuMcUuW4DR6743HuVLNtOds/tj3X9lLb17Qix3K84xvMr6uvr3SVPwAAvTWo3QkAwDLuFUkHSfp+p+1fKvet2mn7DyT9pA/y6ivTJf2TpD/XbPucpPGSviPpLkkLJc0t457s6wRL35R0b6dtT0p6U+/PHwCAXqE5A4D2ulrSgbaPjYiQJNsflPSvkq6SdEhtcES0qzlpiYh4WdIfO23+aPn9xxGxtGZ757i+NDMiujt+n+dle3BEvNnXxwUAtBbLGgGgvS6VNFLSDjXb9pG0vIrm7D06L2u0Pcj2D2w/afsN2wts32F7h073O8z2dNuv215k+zbb23WXlO2tbV9pe055n0dtn1w2jrVxu9q+0/ZLtheXccfW7N/Y9m9tP1/m97TtK2wPKve/Z1lg+diOL+++pGMpY3fLGm3vaHuq7Vdsv2r7ZtujO8Usb/vEcpnka7Z/b/tj3T32nuhmWWbn491q+yOdl1F2t0S1zO/3XRxjX9uTbL8gaX7N/sNs/6nm93++7TU6jTne9sya3/99tvep4jkAAFSHmTMAaK+nJP1BxdLG28ttX5L0W0mLG7j/BElHSvoPSQ9KWk3SVpL+9se57TNULBE8X9JxkpZK+oSk9SX9Tzfjrl+Od5GK5ZUfk3SspL+X9IVy3L+XdJ2kK1Ust3xL0ofLmA43SPqrpK9JWiBpXUm7q/s3B/dRsYTwEBXLBaVi+eDKnQNt/4ukayXdKOnAmufjdtubR8Qz5bbjJR0j6UxJk1U8P9d1c/zuLNfRUJYiIpZ0E3tCe
bzTJd0iacteHK8r/yXpdyr+rawoSbZPUfG7PVvSd1U8vydKGm17u4hYYvsAST+S9J8q/o19UNLmqvk3AgDIA80ZALTfJZJ+ZPubkoZK2kXSbg3e958kTY6I2s+hXd/xg+2NVDRvZ0XEt2tibqw3aET8bdbOtiXdKellSZfYHhcRC1U0HStI+lq5PFGSbq253zAVzdpeEVHbnPyyznEfsP1s+fPflgvafl9zpuKzd7dFxF41cdMkzVLRsHzL9tDy8U+MiH8vwybbXiLplHrPQSc3d7r9rKT1OgeVx/uWpJ9HxIRy8xTbb6tokJpxT0R8peZYG6hoyE6IiP+s2f6YpDsk7SnpGhX/Rh6qjZF0U5O5AABagGWNANB+V0garOKP6QMkzZM0tcH73itpd9sn2d7B9gqd9u+iotZP7ElCtlezfartjpNevK1iCaZVNFxSMbP2tqRf2/6c7bU6DbNQRaN0Srn07sOqSDnWKEmXlUs7B5UzW6+pOInIp8rQzVTMuv2m0xC/7uEhx0nauuZr927iOo53RaftV/bweF35bafb/6zid9v5ObhbRSPd8RzcK2kL2/9lexfbK1WQCwCgBWjOAKDNIuIVFTMcB6lY0nhZpxNh1HOyiqWKn1WxZG2h7QvLWStJ+rvy+5wepnWhpMNVLJf7ZxUNybhy34pl3k9I2lXF/yWXSppn+27bO5b7o7zvfZJ+KOkx27Nsf62HuXSloxE8X0WDWPu1h9593MPL7/P1Xp1vpzwWEffVfD3UTVzH8Z5v8nhdmdvpdsdz8ITe/xyspnefg0tULCvdVsUM4Iu2ry5n3gAAGWFZIwDk4RIVSw2Xk7R/o3eKiLclnSrpVNsfUtGYnClpJUn7qficl1R8FunRRsa0vaKkvSQdX7tc0vZmXRx/mqRptgdL2l7F55putL1BRCyIiFmSvlQujfwHSd+Q9DPbsyPid40+zi4sLL8freJzXZ29VX7vaGjWlvRIzf61mzh2PR3HW6uB472hYlloZ3+ndx9freh0uyPm05IWdRG/UPpbk3yepPPKZZefVrHE8nIVDRsAIBM0ZwCQhykqlt79NSIeSQV3JSLmSfqF7d0ldZyx8BYVJwAZq+JzWI0YrOJskW932n5InWO/KelW26uoOEnHhnq3MexoEB60/W1Jh5b5NdOcPSpptqSPRUS9z449JOlVSZ9XzefhVJ7UpAUeLo83RtK0mu1juoh9StLatodFxAJJsj1K0ibq/kQttaao+N2uHxFTGkkuIhZJutz2tpK+2sh9AAB9h+YMADJQnvmv4RmzDravlfQnFRdzXiTp45I+o2KmRBHxpO2zJH3b9qoqzhq4RNI2kv4SEZd3kctLtv8o6Tu256posv5Nxexb7bEPV/G5ppskPSNpmIqZrOckzbC9uYqTdlyuYund8ioavHf03kapxyIibI+TdG35ObvflHmuLWk7SU9HxJkR8dfy8f+H7VdUnK1xaxUNYuUiYpHtH0s6pjxex9kaO45Xu1z1ChVnubzM9pl69/lboAaUv9tTJf3U9iaSblMxGzdCxXLSX0TENNsTVZxx8y4Vyy03VrGEdnIzjxUAUD2aMwDo3/6gYlZmnIqljE9LOk3SSR0BEfHvtp+Q9HVJB6uY2XlI9f8431/SuZLOkfS6iuZnvIpT43f4k4qzSv5QxTK+F1WcJfCAiHjd9rwyn2+rOLPhGypmlvaIiPubetTF47rJ9qdUXEbgFypOET9PxUWha5vO41WcyOQrKpZV3q3i5Cu9mqFswHHl8Q5VcVmAu1U0pXdKeqkm/ydsf07Fqe+vkfSYiufqmEYPFBHH2J6p4vc/TsXSx2dUnFDm8TLsTklfVtGQra6ief7vMk8AQEZcrDQBAACtYnuMigb3UxFxeyoeALBsojkDAKBC5ee5/kXFjNkbkv5R0vdUfE5uu+A/XgBAN1jWCABAtRar+CzeOBWntH9exazZ0TRmAIB6mDkDAAAAgAxwEWoAAAAAyADNG
QAAAABkgOYMAAAAADJAcwYAAAAAGaA5AwAAAIAM0JwBAAAAQAZozgAAAAAgAzRnAAAAAJABmjMAAAAAyADNGQAAAABkgOYMAAAAADJAcwYAAAAAGaA5AwAAAIAM0JwBAAAAQAZozgAAAAAgAzRnAAAAAJABmjMAAAAAyADNGQAAAABkgOYMAAAAADJAcwYAAAAAGaA5AwAAAIAM0JwBAAAAQAZozgAAAAAgAzRnAAAAAJABmjMAAAAAyADNGQAAAABkgOYMAAAAADJAcwYAAAAAGaA5AwAAAIAM0JwBAAAAQAZozgAAAAAgAzRnAAAAAJABmjMAAAAAyADNGQAAAABkgOYMAAAAADJAcwYAAAAAGaA5Q5+xvZPtOX19XwBIoT4ByBG1adlDc9aP2V5c87XU9us1tw9o4XEPsX1Hq8avgu3tbN9j+xXbD9neod05AcsS6lP3qE9A+1CbukdtygPNWT8WEat0fEl6WtKeNdsu64izPah9WfY922tIuk7S6ZKGSDpN0vW2h7YzL2BZQn3qGvUJaC9qU9eoTfmgORuAOqaxbU+wPU/ShV29Y2M7bG9U/jzY9hm2n7Y93/bPbX+wF8f+su2Z5bsus2x/tYuYY2wvsD279l2qqnKQtJ2k+RFxRUQsiYj/lvSCpH17MRaAClGfqE9AjqhN1KZc0JwNXB+StIakkZLGNhB/qqSNJW0haSNJ60o6thfHfV7SHpJWk/RlSWfZ3rJTXsPK8Q+WNNH2Jj3NwfbPbP+smxxcfnXeNrqnDwZAS1Cf3r+N+gS0H7Xp/duoTX2M5mzgWirpuIh4MyJerxdo25IOk3RkRLwYEa9IOlnSF3p60Ii4MSKejMJtkiZL+mSnsO+Xed0m6UZJn+9pDhHx9Yj4ejdp/I+kdWzvb/sDtg+WNErSSj19PABagvpEfQJyRG2iNrXdMrWedhnzQkS80WDsmipefPcXr3NJxbsly/f0oLZ3k3ScindxlivHfbgmZFFEvFpz+ylJ61SZQ0QstL2XpDMknSPpZkm3SOKMRUAeqE/UJyBH1CZqU9vRnA1c0en2q6p598P2h2r2LZD0uqSPRcSzvT2g7cGSrpL0JUnXRsTbtq/Re6fJh9peuabIrC9pRlU5dCjfWdq6zGuQpCcl/ajZcQFUgvpEfQJyRG2iNrUdyxqXHX+S9DHbW9heUdLxHTsiYqmkSSrWOK8lSbbXtb1rnfFse8XaL0krSBqs4gOk75TvBH26i/ueYHsF259Uscb6il7mUC+5j5fT8qupeBdoTkTc3JuxALQc9Yn6BOSI2kRt6nM0Z8uIiHhM0n+qmKJ+XFLna21MkPSEpD/afrmM20Td207FuzWdv74p6TeSFkn6oorTstaaV+57TtJlkg6PiL/0NAcXZyP6eZ38jlLxjtIzkoZL2qdOLIA2oj5Rn4AcUZuoTe3giM4zuAAAAACAvsbMGQAAAABkgOYMAAAAADJAcwYAAAAAGaA5AwAAAIAM0JwhyfZFtk8sf/6k7Ud7Oc7PbX+/2uwALKuoTQByRG1CM2jOBgjbs22/bnux7fm2L7S9StXHiYjbI6LeaWI78jnE9ntOORsRh0fED6rOKZHHrbajvJgigD5Gbeo2D2oT0EbUpm7zoDa1Gc3ZwLJnRKwiaUsVV3j/f50DlqUXm+0DJC0zjxfIGLWpBrUJyAa1qQa1KQ80ZwNQRDwr6XeSRktS+Q7IONuPq7iIomzvYftB23+1/T+2N++4f3mF+Om2X7F9uaQVa/btZHtOze0Rtq+2/YLthbZ/avujkn4u6Z/Kd6T+Wsb+bZq/vH2Y7Sdsv2j7Otvr1OwL24fbftz2Itvn2Hajz4Ht1SUdp+KCigAyQG2iNgE5ojZRm3JCczYA2R4haXdJD9Rs3lvStpI2tb2lpAskfVXS30k6T9J1tgfbXkHSNZIulbSGpCsk/Ws3x1le0g2SnpK0gaR1Jf06ImZKOlzSXRGxSkQM6eK+/1fSDyV9XsVV6J+S9
OtOYXuoeCfrH8q4Xcv7rl8Wx/XrPA0nSzpX0rw6MQD6ELVJErUJyA61SRK1KRs0ZwPLNeW7LXdIuk3FC63DDyPixYh4XdJhks6LiLsjYklEXCzpTUmfKL8+IOnHEfF2RFwp6d5ujreNpHUkfTciXo2INyLijm5iOztA0gURMT0i3pR0tIp3jDaoiTklIv4aEU9LmiZpC0mKiKcjYki5/X1sbyVpe0n/1WAuAFqL2iRqE5AhapOoTblhXenAsndE3NLNvmdqfh4p6WDbR9RsW0FFwQhJz0ZE1Ox7qpsxR0h6KiLe6UWu60ia3nEjIhbbXqjiXaTZ5ebad29ek5T8oK7t5ST9TNL4iHinBzP6AFqH2kRtAnJEbaI2ZYeZs2VHbdF4RtJJ5bsoHV8rRcSvJM2VtG6ndcrdTYM/I2l9d/1h2ehiW63nVBQ7SZLtlVUsFXg29UASVpO0laTLbc/Tu+9ezbH9ySbHBlA9ahO1CcgRtYna1BY0Z8umSZIOt72tCyvb/hfbq0q6S9I7kr5pe5DtfVVMw3flHhVF6ZRyjBVtb1/umy9pvXItdld+KenLtrewPVjFUoK7I2J2k4/tJRXvLm1Rfu1ebv9HSXc3OTaA1qI2AcgRtQl9huZsGRQR96lYP/1TSYskPSHpkHLfW5L2LW8vkrSfpKu7GWeJpD0lbSTpaUlzynhJulXSI5Lm2V7QxX2nSvq+pKtUFKpRkr7QSP7lB1sXd/XB1ijM6/iS9EK5a3752ABkitoEIEfUJvQlv3eJLAAAAACgHZg5AwAAAIAM0JwBAAAAQAZozgAAAAAgAzRnAAAAAJABmjMAAAAAyEBXF8FrGducGhIYgCLC6ah8UZuAAWtBRKzZ7iSaQX0CBqbu/nZqaubM9mdsP2r7Cdvfa2YsAKgS9QmApKfanUBn1CYA9fS6ObO9vKRzJO0maVNJ+9vetKrEAKC3qE8AckRtApDSzMzZNpKeiIhZ5RXEfy1pr2rSAoCmUJ8A5IjaBKCuZpqzdSU9U3N7TrntPWyPtX2f7fuaOBYA9ESyPlGbALQBfzsBqKuZE4J09SG2931oNSImSpoo8aFWAH0mWZ+oTQDagL+dANTVzMzZHEkjam6vJ+m55tIBgEpQnwDkiNoEoK5mmrN7JX3Y9oa2V5D0BUnXVZMWADSF+gQgR9QmAHX1elljRLxj+xuSbpa0vKQLIuKRyjIDgF6iPgHIEbUJQIoj+m4pM+umgYGJi1ADyNT9EbFVu5NoBvUJGJhachFqAAAAAEA1aM4AAAAAIAM0ZwAAAACQAZozAAAAAMgAzRkAAAAAZIDmDAAAAAAyQHMGAAAAABmgOQMAAACADNCcAQAAAEAGaM4AAAAAIAM0ZwAAAACQAZozAAAAAMgAzRkAAAAAZIDmDAAAAAAyQHMGAAAAABmgOQMAAACADNCcAQAAAEAGaM4AAAAAIAM0ZwAAAACQAZozAAAAAMgAzRkAAAAAZIDmDAAAAAAyQHMGAAAAABmgOQMAAACADNCcAQAAAEAGaM4AAAAAIAM0ZwAAAACQAZozAAAAAMgAzRkAAAAAZIDmDAAAAAAyQHMGAAAAABmgOQMAAACADNCcAQAAAEAGBrU7AfStNdZYIxkzbdq0ZMxmm21Wd/8999yTHOOKK65IxkydOjUZ8+CDDyZjAPR/Y8aMScaMHTu27v6dd965qnSS5s6dm4xpJJ+//OUvVaQDoM1WX331uvvHjx+fHKOROjhq1KhkzEknnVRJDKrHzBkAAAAAZKCpmTPbsyW9ImmJpHciYqsqkgKAZlGfAOSI2gSgniqWNf6fiFhQwTgAUDXqE4AcUZsAdIlljQAAAACQgWabs5A02fb9trv8FLbtsbbvs31fk8cCgJ6oW5+oTQDahL+dAHSr2WWN20fEc7bXkjTF9l8i4g+1ARExUdJESbIdTR4PABpVtz5RmwC0CX87AehWUzNnEfFc+f15Sb+VtE0VSQFAs6hPAHJEbQJQT6+bM
9sr216142dJn5Y0o6rEAKC3qE8AckRtApDSzLLGtSX91nbHOL+MiP9fSVZomS9+8YvJmNGjRydjXnjhhbr777333uQYm2++eTJmypQpyRigC9SnjHzwgx9MxkycODEZ87nPfS4ZM3jw4Lr7Z8xI/x187bXXJmOGDBmSjBk3blwy5u67707G7L777nX333nnnckxkA1q0wC16aabJmNuvvnmuvvXW2+9qtJJOvHEE5Mxb775ZjLmjDPOqCId1Oh1cxYRsyT9Q4W5AEAlqE8AckRtApDCqfQBAAAAIAM0ZwAAAACQAZozAAAAAMgAzRkAAAAAZIDmDAAAAAAyQHMGAAAAABmgOQMAAACADDRzEWr0Q8OGDatknKOOOqru/osvvriS4wDIW+qiz5J0/vnnJ2P222+/ZMz999+fjDn55JPr7r/hhhuSY7zzzjvJmEZMmjQpGXPllVcmYy677LK6+3faaafkGLNnz07GAOhaIxedb6TOpS4yffrppyfHOPbYY5Mxe++9dzLm3HPPTcaMHDkyGYPqMXMGAAAAABmgOQMAAACADNCcAQAAAEAGaM4AAAAAIAM0ZwAAAACQAZozAAAAAMgAzRkAAAAAZIDmDAAAAAAywEWoB5AVV1wxGdPIhQkbsWDBgkrGAdC/nXDCCcmYqi4wvdtuuyVjFi5cmIzpKw899FAyZpdddknGTJ8+ve7+XXfdNTnGeeedl4wB0LWzzjorGfOJT3wiGXPEEUfU3d/IhaGXLFmSjPn973+fjBk/fnwy5pRTTknGpB4Teo6ZMwAAAADIAM0ZAAAAAGSA5gwAAAAAMkBzBgAAAAAZoDkDAAAAgAzQnAEAAABABmjOAAAAACADXOdsADn00EOTMZtttlky5uyzz07G3HTTTQ3lBGBgmzBhQjLm5ZdfTsYcffTRyZicrmFWlaeffjoZM2fOnLr7G7k20vXXX5+Mee6555IxwEDzta99LRlz8MEHJ2MuuuiiZEzqtdrINczGjBmTjDnnnHOSMfvuu28yppG6geoxcwYAAAAAGaA5AwAAAIAM0JwBAAAAQAZozgAAAAAgAzRnAAAAAJABmjMAAAAAyADNGQAAAABkgOYMAAAAADLARagHkJVWWikZYzsZ89hjjyVjIqKhnPrCoEHpf8bLLZd+H+Ktt96qIh1gQBk7dmzd/Y3UgiuuuCIZM3Xq1IZzGkh22GGHZMxHP/rRuvsb+R3svPPOyZhLL700GQMMNPvvv38y5plnnknGTJgwIRnTyEWmUzbaaKNkzJprrpmM2W+//ZIxkydPbignVIuZMwAAAADIQLI5s32B7edtz6jZtobtKbYfL78PbW2aAPB+1CcAOaI2AeitRmbOLpL0mU7bvidpakR8WNLU8jYA9LWLRH0CkJ+LRG0C0AvJ5iwi/iDpxU6b95J0cfnzxZL2rjYtAEijPgHIEbUJQG/19oQga0fEXEmKiLm21+ou0PZYSfU/UQ4A1WmoPlGbAPQx/nYCkNTyszVGxERJEyXJdj6n+AOwTKM2AcgV9QlYdvX2bI3zbQ+XpPL789WlBABNoT4ByBG1CUBSb5uz6yQdXP58sKRrq0kHAJpGfQKQI2oTgKTkskbbv5K0k6RhtudIOk7SKZJ+Y/tQSU9LGtPKJCENHZo+4+5hhx2WjHnzzTeTMddff31DOeWikXxHjhyZjNl0002rSAd9iPrUevvuu2/TY8ycObOCTAamiRMnJmMGDar/X/Vrr72WHOOGG25oOCc0j9qUh8MPPzwZs/322ydjjj/++GTM88/3zUTolltuWck4qboiSVdddVUlx0LPJH8zEdHdpdN3rjgXAOgR6hOAHFGbAPRWb5c1AgAAAAAqRHMGAAAAABmgOQMAAACADNCcAQAAAEAGaM4AAAAAIAM0ZwAAAACQAZozAAAAAMhA+gp0yMJee+2VjBk1alQy5qyzzkrGzJkzp6Gc+sImm2ySjNlmm22SMUOGDKkgGwBo3Jgx6WsMb7jhh
k0fZ9y4ccmYRYsWNX0coL854IADkjGNXMT9sssuqyKdpLXWWisZs+OOO1ZyrDfeeKOScVA9Zs4AAAAAIAM0ZwAAAACQAZozAAAAAMgAzRkAAAAAZIDmDAAAAAAyQHMGAAAAABmgOQMAAACADNCcAQAAAEAGuAh1PzF69OhKxrnzzjsrGacK2223XTKmkQs/VnWB6WHDhtXdv2DBgkqOA/QnttudQp8bPHhwMuaEE05Ixhx11FFVpKOXXnqp7v4pU6ZUchygv1l11VXr7m/kb6cZM2YkY2bNmtVwTs046KCDkjFrrrlmH2SCdmLmDAAAAAAyQHMGAAAAABmgOQMAAACADNCcAQAAAEAGaM4AAAAAIAM0ZwAAAACQAZozAAAAAMgA1znrJ4444ohkzGuvvZaMeeCBB6pIJ2mHHXZIxlx++eXJmA996ENVpNOQVM7XXHNN3yQCZOTFF19seoxG6tfvfve7ZMyf//znZEzqukdf+cpXkmNMmDAhGVPVtYYiIhlz5JFH1t0/d+7cSnIB+ptjjz227v5GroN60003VZRN2kYbbVR3/7hx4/ooE+SMmTMAAAAAyADNGQAAAABkgOYMAAAAADJAcwYAAAAAGaA5AwAAAIAM0JwBAAAAQAZozgAAAAAgAzRnAAAAAJABLkKdifXXX7/u/hVWWCE5xtVXX52MmT17djJm0KD0P4tTTjml7v7vfOc7yTEaufjqaaedloy57bbbkjE33nhjMgbA+33961+vu/+zn/1scowRI0YkY2699dZkTCMXW15ttdXq7t9ggw2SY/SlF154IRlz880390EmQP+Tuuh8IzbeeONkzDrrrJOM2X777ZMxp556at39qfolSQsWLEjGDBs2LBmDfDFzBgAAAAAZSDZnti+w/bztGTXbjrf9rO0Hy6/dW5smALwf9QlAjqhNAHqrkZmziyR9povtZ0XEFuXXTdWmBQANuUjUJwD5uUjUJgC9kGzOIuIPkl7sg1wAoEeoTwByRG0C0FvNfObsG7YfKqfuh3YXZHus7fts39fEsQCgJ5L1idoEoA342wlAXb1tzs6VNErSFpLmSvpRd4ERMTEitoqIrXp5LADoiYbqE7UJQB/jbycASb1qziJifkQsiYilkiZJ2qbatACgd6hPAHJEbQLQiF41Z7aH19zcR9KM7mIBoC9RnwDkiNoEoBHJqw3b/pWknSQNsz1H0nGSdrK9haSQNFvSV1uX4rJhyy23rLu/kQs2N3KB1iFDhiRjzj777GTMAQccUHd/I/lOnjw5GXPiiScmY7baKr3qo5F80P9Qn1pv0aJFdfdvu+22yTEuvPDCZEyqBkrSmmuumYyxXXf/iy+mz9EwadKkZMxRRx2VjGnEpZdemoxppLYjL9SmvjF9+vS6+5cuXZoc48ADD0zG7LnnnsmY1VdfPRnz9ttv192///77J8do5KLZJ598cjJmxIgRyRi0R7I5i4iu/qWc34JcAKBHqE8AckRtAtBbzZytEQAAAABQEZozAAAAAMgAzRkAAAAAZIDmDAAAAAAyQHMGAAAAABmgOQMAAACADNCcAQAAAEAGktc5Q/8xcuTIZMxdd92VjGnkAocpP/3pT5MxxxxzTDJm8eLFTecCoHVmzJiRjNl6662TMXvssUcyZujQocmY1AWbb7nlluQYW2yxRTJmwoQJyZhGnHHGGZWMAyyLJk6cWHf/kCFDkmOMGzeuklxuv/32ZMwll1xSd/9VV12VHOPoo49uOKd6Ntlkk0rGQfWYOQMAAACADNCcAQAAAEAGaM4AAAAAIAM0ZwAAAACQAZozAAAAAMgAzRkAAAAAZIDmDAAAAAAyQHMGAAAAABngItSZWLBgQdNj7L333s0nIumJJ55IxkyaNKnu/r68sOpmm21WyTgPP/xwJeMA6LkbbrihT46zzjrrJGPOPPPMZExEJGMef/zxZMzixYuTMQB657TTTqskZiCq4u9OtAYzZwAAAACQAZozAAAAAMgAz
RkAAAAAZIDmDAAAAAAyQHMGAAAAABmgOQMAAACADNCcAQAAAEAGaM4AAAAAIANchDoTd9xxR939Y8eOTY6x2267JWMeffTRZMzpp5+ejFm0aFEypq9suOGGlYzz5JNPVjIOgHyNGDEiGbPjjjtWcqxx48YlY1599dVKjgUAPbHGGmu0OwV0g5kzAAAAAMgAzRkAAAAAZIDmDAAAAAAyQHMGAAAAABmgOQMAAACADNCcAQAAAEAGaM4AAAAAIANc56yf+MUvflFJzEBku5IYAAPfd7/73UrGmTVrVjJm6tSplRwLAKo2cuTIZMzw4cOTMXPnzq0iHdRg5gwAAAAAMpBszmyPsD3N9kzbj9geX25fw/YU24+X34e2Pl0AKFCbAOSK+gSgtxqZOXtH0nci4qOSPiFpnO1NJX1P0tSI+LCkqeVtAOgr1CYAuaI+AeiVZHMWEXMjYnr58yuSZkpaV9Jeki4uwy6WtHeLcgSA96E2AcgV9QlAb/XohCC2N5D0cUl3S1o7IuZKRRGyvVY39xkraWyTeQJAt6hNAHJFfQLQEw03Z7ZXkXSVpG9FxMuNnv0uIiZKmliOEb1JEgC6Q20CkCvqE4CeauhsjbY/oKK4XBYRV5eb59seXu4fLun51qQIAF2jNgHIFfUJQG80crZGSzpf0syIOLNm13WSDi5/PljStdWnBwBdozYByBX1CUBvNbKscXtJB0l62PaD5bZjJJ0i6Te2D5X0tKQxLckQSIhIr/hoJAb9DrUJ7zN69Oi6+3fbbbfkGI0sPRs1alTDOWGZRH1CS0ybNi0Zs3Tp0mTM6quvnoz5yEc+kozhItTVSzZnEXGHpO7+p9q52nQAoDHUJgC5oj4B6K2GPnMGAAAAAGgtmjMAAAAAyADNGQAAAABkgOYMAAAAADJAcwYAAAAAGaA5AwAAAIAM0JwBAAAAQAYauQg1AABtN3To0GTMgQceWHf/iiuumByDi9YDyNUf//jHZMxPfvKTZMyRRx6ZjNl3332TMY1cFBs9w8wZAAAAAGSA5gwAAAAAMkBzBgAAAAAZoDkDAAAAgAzQnAEAAABABmjOAAAAACADNGcAAAAAkAGaMwAAAADIABehRr83ZcqUZMz48eP7IBMArbTHHnskY4466qi6+xu5wPTChQuTMQcddFAyBgDaYd68eZWMs8suuyRj1l577br758+fX0kuyxJmzgAAAAAgAzRnAAAAAJABmjMAAAAAyADNGQAAAABkgOYMAAAAADJAcwYAAAAAGaA5AwAAAIAM0JwBAAAAQAa4CDX6vQceeCAZM2vWrD7IBEDuFi1alIxp5GLX99xzTxXpAEDlJk+enIyZMGFCMua2225LxixdurShnNA4Zs4AAAAAIAM0ZwAAAACQAZozAAAAAMgAzRkAAAAAZIDmDAAAAAAyQHMGAAAAABmgOQMAAACADNCcAQAAAEAGHBF9dzC77w4GoM9EhNudQzOoTcCAdX9EbNXuJJpBfQIGpu7+dkrOnNkeYXua7Zm2H7E9vtx+vO1nbT9Yfu1eddIA0B1qE4BcUZ8A9FZy5sz2cEnDI2K67VUl3S9pb0mfl7Q4Is5o+GC8+wMMSO2YOaM2AWhAW2bOqE8AUrr722lQA3ecK2lu+fMrtmdKWrfa9ACgZ6hNAHJFfQLQWz06IYjtDSR9XNLd5aZv2H7I9gW2h3Zzn7G277N9X3OpAkDXqE0AckV9AtATDZ8QxPYqkm6TdFJEXG17bUkLJIWkH6iYvv+3xBhMzQMDUDtPCEJtAlBHW08IQn0C0J1enxBEkmx/QNJVki6LiKvLAedHxJKIWCppkqRtqkoWABpBbQKQK+oTgN5o5GyNlnS+pJkRcWbN9uE1YftImlF9egDQNWoTgFxRnwD0ViNna9xB0u2SHpa0tNx8jKT9JW2hYmp+tqSvlh+ArTcWU/PAANSmszVSmwCktOtsjdQnAHV197cTF6EG0DQuQg0gU1yEGkCWmvrMGQAAAACgt
WjOAAAAACADNGcAAAAAkAGaMwAAAADIAM0ZAAAAAGSA5gwAAAAAMkBzBgAAAAAZoDkDAAAAgAzQnAEAAABABmjOAAAAACADNGcAAAAAkAGaMwAAAADIAM0ZAAAAAGSA5gwAAAAAMkBzBgAAAAAZoDkDAAAAgAwM6uPjLZD0VM3tYeW2/oJ8W4t8W6tV+Y5swZh9rXNtkvj9thr5thb5FgZifeJ321rk21rkW+i2NjkiWnC8xti+LyK2alsCPUS+rUW+rdXf8m23/vZ8kW9rkW9r9bd826m/PVfk21rk21rtyJdljQAAAACQAZozAAAAAMhAu5uziW0+fk+Rb2uRb2v1t3zbrb89X+TbWuTbWv0t33bqb88V+bYW+bZWn+fb1s+cAQAAAAAK7Z45AwAAAACI5gwAAAAAstC25sz2Z2w/avsJ299rVx6Nsj3b9sO2H7R9X7vz6cz2Bbaftz2jZtsatqfYfrz8PrSdOdbqJt/jbT9bPscP2t69nTnWsj3C9jTbM20/Ynt8uT3L57hOvtk+x7mgNlWL2tRa1KZlC/WpWtSn1qE2NZFLOz5zZnt5SY9J+mdJcyTdK2n/iPhznyfTINuzJW0VEVleOM/2pyQtlnRJRIwut50m6cWIOKUs4kMjYkI78+zQTb7HS1ocEWe0M7eu2B4uaXhETLe9qqT7Je0t6RBl+BzXyffzyvQ5zgG1qXrUptaiNi07qE/Voz61DrWp99o1c7aNpCciYlZEvCXp15L2alMuA0JE/EHSi5027yXp4vLni1X8I8tCN/lmKyLmRsT08udXJM2UtK4yfY7r5Iv6qE0Voza1FrVpmUJ9qhj1qXWoTb3XruZsXUnP1Nyeo/yLc0iabPt+22PbnUyD1o6IuVLxj07SWm3OpxHfsP1QOXWfxVR3Z7Y3kPRxSXerHzzHnfKV+sFz3EbUpr6R/eumC9m/bqhNAx71qW9k/9rpQtavHWpTz7SrOXMX23I/p//2EbGlpN0kjSunllGtcyWNkrSFpLmSftTWbLpgexVJV0n6VkS83O58UrrIN/vnuM2oTehK9q8batMygfqErmT92qE29Vy7mrM5kkbU3F5P0nNtyqUhEfFc+f15Sb9Vsbwgd/PLNbQda2mfb3M+dUXE/IhYEhFLJU1SZs+x7Q+oeMFeFhFXl5uzfY67yjf35zgD1Ka+ke3rpiu5v26oTcsM6lPfyPa105WcXzvUpt5pV3N2r6QP297Q9gqSviDpujblkmR75fLDgbK9sqRPS5pR/15ZuE7SweXPB0u6to25JHW8WEv7KKPn2LYlnS9pZkScWbMry+e4u3xzfo4zQW3qG1m+brqT8+uG2rRMoT71jSxfO93J9bVDbWoil3acrVGSXJyK8seSlpd0QUSc1JZEGmD771W84yNJgyT9Mrd8bf9K0k6ShkmaL+k4SddI+o2k9SU9LWlMRGTxQdJu8t1JxbRxSJot6asd65LbzfYOkm6X9LCkpeXmY1SsR87uOa6T7/7K9DnOBbWpWtSm1qI2LVuoT9WiPrUOtamJXNrVnAEAAAAA3tW2i1ADAAAAAN5FcwYAAAAAGaA5AwAAAIAM0JwBAAAAQAZozgAAAAAgAzRnAAAAAJABmjMAAAAAyMD/AlxEqezYGMLXAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAA2cAAAFoCAYAAADTgoOZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAuu0lEQVR4nO3deZxcZZX/8e9XwhJIhCBgCAYhiAuLEomI4JIQBQbxhxsqioIigVEHNAoyjNiJDgr5GXFfUDYZXEAgMOAIBBJQWdMZhGBYQwhZgSRoQFQgZ/64t6VoquvervXp7s/79epXd9176j6nbnedrlP31n0cEQIAAAAAdNaLOp0AAAAAAIDmDAAAAACSQHMGAAAAAAmgOQMAAACABNCcAQAAAEACaM4AAAAAIAE0ZwBQhe032b7Q9nLb/7C92vY1to+wvUGn8yvL9g62p9keV2XdYtvntjGXc20vbtd4vcaea3tur2Vvsn2L7Sdth+09mp1jvu8L56zJ84sqX0v7yh8AMPgM63QCAJAa25+V9E1J10n6oqSHJI2StL+kH0p6XNJlHUqvv3aQ1CXp95IWdTaVjvpUlWVnSXpK0rsk/VXSvZK+Kunbbcyr0h2Sjum17O/592r5AwAGGZozAKhg+63KGrPvRcRxvVZfZvubkjZrwjgbR8TfqyzfUNIzEVF4tAXlRcSfKm/bfpGkV0k6NSKuq1j1QFsTe751EXFztRW982+Xvv5OAQCtwWmNAPB8J0laI+nEaisj4oGIuKPntu29bM+2/UR+ety1tveqvE9+qtzS/DS6G20/JWlGfsph2P6U7Rm2lys7UrJFfr/32r7Z9l9tP277Itvb987J9tG259t+yvZa29fb3sf2RElz8rBrKk6Vm1hlG3vm6w6psq4n/5qnc9re0fb5tlfa/rvtRbZrHoWyPT3P/c+2H7N9ne29e8WMsP1d20vy7a7K9/mrK2KOt72wYh/Ms/2eivX/PC3Q9pGSnlX2P/CU/HEvrnisi3uNv6nt020/mJ/i+qDt/8gbvMq48bZ/Z/tvtpfZPkWSaz3+svo4LfP1FeM9bPvkfH9GRUzP39iRve47sfffQj7G722/y/b/2v678iN2+e/2AtuP5r+D2yv3bx7zStuX2n4kz2lJ/jfLG8EAUBIFEwByefMxUdKsiPhbifjXSrpe0p8kHSkplDV319veOyL+WBG+uaRfSvqGpJOVnU7X4z8k3SZpiqQNJP3N9rHKTqE8R9JXJI2UNC3f9msjYl2ewzckfV7ZKXpdktZL2lvS9pJ+I+nTkr4v6bh8DOX5Pk9EdNu+Tdlpdf88ZdP2FpI+IGlGRDxbY1/sKOlWZacHdkm6T9JYZaeC1rKdpDMkLVV2RPJwSTfYnlDRBJ8h6f8p22/3SXqJpH31XBP7EUkz8/30O0nDJb1W0pZ9jHmlpDcrO9XzLEk/1XOnD/Z+XMMkXSVpF2WnPN6pbP+ekm//83ncVspOg10p6Yh8eyco+z2UVqWRebbaUdR8vGslLZf0MUn/kPQ5ZaexNuKVkr6j7LEukrTG9lhJt0h6JB/jUUkflHSx7XdHxOX5fa9Qdsrvv0p6TNnv9iDxRjAAlEZzBgDP2UrZC/uHSsZ/WdmL8MkR8bgk2b5G0mJlDcp7K2JHSDo8Iiobnx3yH1dJek/Pi3DbIySdLumciPhERfwtyj4XdZSkb9l+hbIXy2dExNSKsa6suE9PI7awr1PmKvxA0lm2Xx4RPfvgY5I2UtbA1DJd2b57XUQsr1h+Xq07RcQnK3LdQNJvJd2l7DEen696k6QLIuKsirteWvHzmyTdERFfqVj2mxpjPmp7bX5zacF+OUxZI/e2iLghX3atbUnqsn16RPQ0LZtJOiAiluSP5xqV/1uSs
obz6V7Ljlb1fT+1Yryei4ZcpexvrxFbSdo/Im7vWWD7LGVHAN8WEavzxVflTdtXJF2eN4s7SzqkolmTpJ83mA8ADCm8mwUA9XurpCt6GjNJioi/SLpc0tt6xT6j7MhCNbN6HR15k6QXS7rA9rCeL2VHl+7Ox5Wktyur42c2+kByv1R25OPoimXHSLqyogHYoDIn512KsiNkV/RqzArZfrvtObZXK9tHTys7evOqirDbJB2Zn7Y3wS88vfI2SXvkpz6+3fam/cmhwIHKGqwbe/0urpa0obKjaFL2O7u5pzGTpIh4UtJ/92OsP0p6Q6+vWX3E7i3ppp7fSz7eU6pozOu0uLIxyx2orNn9c699cJWk19l+saTVyo60nebsNNudG8wDAIYkmjMAeM5qZacbvrxk/JaSVlRZvlLZ1R0rPVLjtMDe29gm/z5bWbNS+bW7stP6VPF9qZogP5XzHElH5S/A36LsdL4fVYQ90CufIypy6Vcetl+v7EX/E8qOlO2trCH5o6RNKkL/TdKPJX1CWSP2iO0zKpqwnyk7le6NyhqGNbYvqTgy2YhtlP099P493Jqv7/kdbKvsCGhv1Zb15YmImNfr67E+YrdVdpphI+NVU+3veRtlR1B774P/n69/Sf7mwjskzZP0dUn3OvvM4b82mA8ADCmc1ggAuYh4Jr/owjtc7ip1aySNrrJ8dL7ueZuvNXSv2z2njh2p7BS/3tbl33teuG8n6Z4a2++PHyo7Ze4QSe9RdprcVRXr3yVp44rbD1bksl0/x3qfsqNl742If57OZ3uUsiN4kqSIeELSv0v6d9svl/R+Sacp+5zVF/PG4MeSfpzfd39ln0H7lbKGrRGrlT3GD/SxfnH+fYWkl1ZZX21ZM6zQc018rfF6Pju5Ua/lL1F11f5OVyv7LN/pfdxnuSRFxCJJH8uPpr5O0mck/cD24oj4nz7uCwCoQHMGAM93mqS5yo4K9L6Ufs+FL0bmF6u4XtI7bY+suEDHSGUNzNwGcrhRWQP2ioio9Zmt2couADJF+YUpquhpMIeXGTgiHrB9tbKLWewh6SsRsb5i/Z193PVqSe+1vW1EVDv6Us2myq6aWHl1wf2UXUTjwWp3yD8LNzO/CMhuVdavlfQr22/UC+cMq8dvlTWRT0TE3TXibpJ0gu2xEfGwJNneTNnfQivcLOkLtl9WccrpcEnv7BW3StnfQO991Tuult8qO23zrvzUyZryZvl221OVHRHdTRLNGQCUQHMGABUi4ob8ReU3bb9G0rmSlig7TXGypE9K+rCyCYO/KulgZReIOF1Zk/FFZU3HV1649dI5/MX2CZK+b3trZS9s/6zsyNTbJM2NiJ/njdQZkqbmTeHlypqdvSTdHRG/UnYBkWckfcL2GmUv1O/paSb78ANlV2x8WtLZJdPuUvaC/0bbX5N0f57vgRFxeB/3+a2kz0o61/Y5yj5rdoqkZZVBtm/KH9udyk6BfJuyIzPn5evPVNbM3qTsVL9XSvqosoaxURdI+riy3/FMZadcbiRpJ2VXkHx3RPxV2RUlPyXpatvT9NzVGgubmTp9U9mpnFfZnp6PNzX//s9mNyLC9q+Unap6r7IjrO9UdlXSsr6s7DTOG2x/T9nRwlHKmq5xEfGJ/Mql31Z2tPJ+ZVcdPVLZ3951VbYJAKiC5gwAeomIb9m+VdkV+L6h7Ap265R9nuYY5Rd5iIg7nM0TdaqyRsHKjmi8rddl9OvJ4ce2H1b2Av/Dyi4+sUzSDZJur4j7gu37lTUGR0h6UlnjeHW+frXtzyhrGq9X9qJ5kmof2btS2SXxfxMRK0vmuzg/WvWfyj5zNDLP97Ia97nK9nHKmor3SVqg7LNNX+oVeoOy0wpPUvZ/a5Gkz0XEd/L1f1DWQH1U2ZQFyyX9l7KGsSER8bTtA/Kxp0jaUdk+fkDZfvpHHveY7cnKGpTzlJ0K+KM83y83mkeVvHrG+46yz9z1jLeVsn1Y6XhlnzGfln+/UNnn+
Pq6QE3vsZbYnpDf/2uSts7HW6Dnrsa5UtmbGFMlvUzZ6ZR3Sjo4IrrreYwAMBS5yvQpAIAhzPY7lDV3b4+IazudD8rJr2I5X9JjETG50/kAAPqPI2cAAEmS7Z0kjVN2it58GrO02f6qslMIH1J2gY9PKpt8+6BO5gUAqB/NGQCgxymSDlf2uarep8YhPaHslMkx+c93KPsMHBffAIABitMaAQAAACABTEINAAAAAAmgOQMAAACABNCcAQAAAEACaM4AAAAAIAE0ZwAAAACQAJozAAAAAEgAzRkAAAAAJIDmDAAAAAASQHMGAAAAAAmgOQMAAACABNCcAQAAAEACaM4AAAAAIAE0ZwAAAACQAJozAAAAAEgAzRkAAAAAJIDmDAAAAAASQHMGAAAAAAmgOQMAAACABNCcAQAAAEACaM4AAAAAIAE0ZwAAAACQAJozAAAAAEgAzRkAAAAAJIDmDAAAAAASQHMGAAAAAAmgOQMAAACABNCcAQAAAEACaM4AAAAAIAE0ZwAAAACQAJozAAAAAEgAzRkAAAAAJIDmDAAAAAASQHMGAAAAAAmgOQMAAACABNCcAQAAAEACaM4AAAAAIAE0Z2gb2xNtL233fQGgCPUJQIqoTUMPzdkAZvuJiq/1tp+quP2RFo57pO3ft2r7zWB7H9u32l5n+w7bb+50TsBQQn3qG/UJ6BxqUzHbb7Mdtv+z07kMRcM6nQDqFxEjen62vVjSJyNidu8428Mi4pl25tZJtreUdLmkf5V0iaTDJP237XERsbajyQFDBPWpOuoT0FnUptpsbyjp25Ju6XQuQxVHzgahnsPYtr9oe6Wkc6q9Y5O/K/KK/OeNbX/D9hLbq2z/yPbwOsb+uO2F+TvCi2wfUyXmZNuP2V5c+S5Vs3KQtI+kVRFxUUQ8GxH/JelRSe+tY1sAmoj6RH0CUkRt+qfPS7pa0t0NbAMNoDkbvEZL2lLSyyVNKRF/uqRXStpD0iskbSfpy3WM+4ikgyW9WNLHJZ1h+/W98toq3/4Rks60/ar+5mD7B7Z/0EcOzr96L9utvw8GQEtQn164jPoEdN5Qrk2y/XJJn5D0lToeA5qE5mzwWi+pKyL+HhFP1Qq0bUlHS/pcRKyJiHWSvibpQ/0dNCKujIgHInO9sndf3tIr7JQ8r+slXSnpA/3NISI+FRGf6iONGyWNsX2Y7Q1tHyFpJ0mb9vfxAGgJ6hP1CUjRUK5NkvSdfJwn+vsY0Dx85mzwejQi/lYydmtlLwy6s+e5pOyd3A36O6jtf5HUpexdnBfl272zImRtRDxZcfshSWOamUNErLZ9iKRvSPq+pKskzZbEFYuANFCfqE9AioZsbbL9LkkjI+JX/b0vmovmbPCKXrefVMU7s7ZHV6x7TNJTknaNiGX1Dmh7Y0kXS/qYpMsi4mnbs/T8U3hG2d6soshsL2lBs3Lokb+z9IY8r2GSHpA0s9HtAmgK6hP1CUjRUK5NkyVNyD9vJ0mbS3rW9u4RcUiD20Y/cFrj0PFHSbva3sP2JpKm9ayIiPWSfqLsHOdtJMn2drYPqLE9296k8kvSRpI2Vvbh9mfyd4L2r3Lf6bY3sv0WZedYX1RnDrWSG5+fMvRiZe9QL42Iq+rZFoCWoz5Rn4AUDaXadIqe++zaHsquKvsTZZ+BQxvRnA0REXGvsg94zpZ0n6Tec218UdL9km62/Zc87lXq2z7K3q3p/XWcpAslrZX0YWVP7kor83XLJV0g6diI6LkiUOkc8qsR/ahGficqe0fpYUnbSnpPjVgAHUR9oj4BKRpKtSki1kXEyp6vPK8nI2JNjceDFnBE7yO4AAAAAIB248gZAAAAACSA5gwAAAAAEkBzBgAAAAAJoDkDAAAAgATQnKGQ7XNt/2f+81ts31Pndn5k+5TmZgdgqKI2AUgRtQmNoDkbJGwvtv2U7Sdsr7J9ju0RzR4nIn4XEbUuE9uTz5G2n3fJ2Yg4NiK+2uycqoy9se0zbC+3vdb2D
2xv2OpxAbwQtel5Y1ObgERQm/rM4zrbYXtYO8fFc2jOBpd3RcQISa+X9AZJX+odMESebCdJmiBpN2UTKr5eVfYFgLahNmWoTUBaqE0VbH9E0pB5vKmiORuEImKZpP9R9gJA+Tsgn7Z9n7JJFGX7YNu3237c9o22X9tzf9vjbc+3vc72ryRtUrFuou2lFbfH2r7E9qO2V9v+nu3XSPqRpDfl70g9nsf+8zB/fvto2/fbXmP7cttjKtaF7WNt35e/w/x92y65C94l6TsRsSYiHpX0HUmf6OduBNBk1CZqE5AiapNke3NJXZJO7OfuQ5PRnA1CtsdKOkjS/1YsfrekN0raxfbrJZ0t6RhJL5H0Y0mXOzvlZiNJsySdL2lLSRdJel8f42wg6QpJD0naQdJ2kn4ZEQslHSvppogYERFbVLnvfpK+LukDkrbNt/HLXmEHK3sn63V53AH5fbfPi+P2fe2C/Kvy9svywgOgQ6hN1CYgRdQmSdLXJP1Q0soaMWgDmrPBZVb+bsvvJV2v7InW4+v5u7VPSTpa0o8j4paIeDYizpP0d0l7518bSvpWRDwdEb+WdFsf4+0laYykEyLiyYj4W0T8vo/Y3j4i6eyImB8Rf5f078reMdqhIua0iHg8IpZImiNpD0mKiCURsUW+vJr/kXS87a1tj5Z0XL5805K5AWgualOG2gSkhdokyfYESftK+m7JXNBCnFc6uLw7Imb3se7hip9fLukI2/9WsWwjZQUjJC2LiKhY91Af2xwr6aGIeKaOXMdImt9zIyKesL1a2btIi/PFle/e/FVS2Q/qnippC0m3KyueP5E0XtIjdeQJoHHUpgy1CUjLkK9Ntl8k6QeSjo+IZ/pxJiRahCNnQ0dl0XhY0qn5uyg9X5tGxC8krZC0Xa/zlPs6DP6wpO1d/cOyUWVZpeXKip0kyfZmyk4VWFb0QIpExFMR8ZmI2C4ixklaLak7Ip5tdNsAmo7aRG0CUjRUatOLlV2o6Fe2V+q5o35Lbb+lwW2jDjRnQ9NPJB1r+43ObGb7nbZHSrpJ0jOSjrM9zPZ7lR2Gr+ZWZUXptHwbm9jeN1+3StlnKTbq474/l/Rx23vY3ljZqQS3RMTiRh+c7e1sj8kf296STlH2IVcAaaM2AUjRYK5Nf1Z2VG6P/OugfPmekm5pcNuoA83ZEBQR85SdP/09SWsl3S/pyHzdPyS9N7+9VtIHJV3Sx3aeVXb1sVdIWiJpaR4vSddJukvSStuPVbnvtcpemFysrFDtJOlDZfLPP9j6RI0Ptu4k6UZJT0o6T9JJEXF1mW0D6BxqE4AUDebaFJmVPV+SHs1XrcofG9rMzz9FFgAAAADQCRw5AwAAAIAE0JwBAAAAQAJozgAAAAAgATRnAAAAAJAAmjMAAAAASEC1SfBaxjaXhgQGoYhwcVS6qE3AoPVYRGzd6SQaQX0CBqe+Xjs1dOTM9oG277F9v+2TGtkWADQT9QmApIc6nUBv1CYAtdTdnNneQNL3Jf2LpF0kHWZ7l2YlBgD1oj4BSBG1CUCRRo6c7SXp/ohYlM8g/ktJhzQnLQBoCPUJQIqoTQBqaqQ5207SwxW3l+bLnsf2FNvzbM9rYCwA6I/C+kRtAtABvHYCUFMjFwSp9iG2F3xoNSLOlHSmxIdaAbRNYX2iNgHoAF47AaipkSNnSyWNrbj9MknLG0sHAJqC+gQgRdQmADU10pzdJmln2zva3kjShyRd3py0AKAh1CcAKaI2Aaip7tMaI+IZ25+RdJWkDSSdHRF3NS0zAKgT9QlAiqhNAIo4on2nMnPeNDA4MQk1gER1R8SETifRCOoTMDi1ZBJqAAAAAEBz0JwBAAAAQAJozgAAAAAgATRnAAAAAJAAmjMAAAAASADNGQAAAAAkgOYMAAAAABJAcwYAAAAACaA5AwAAAIAE0JwBAAAAQAJozgAAAAAgATRnAAAAAJAAmjMAAAAASADNGQAAAAAkgOYMAAAAABJAcwYAAAAAC
aA5AwAAAIAE0JwBAAAAQAJozgAAAAAgATRnAAAAAJAAmjMAAAAASADNGQAAAAAkgOYMAAAAABJAcwYAAAAACaA5AwAAAIAE0JwBAAAAQAJozgAAAAAgATRnAAAAAJAAmjMAAAAASADNGQAAAAAkgOYMAAAAABJAcwYAAAAACaA5AwAAAIAEDOt0AgAAAADQl2nTphXGdHV1FcZMmjSpMGbu3LklMmodjpwBAAAAQAIaOnJme7GkdZKelfRMRExoRlIA0CjqE4AUUZsA1NKM0xonRcRjTdgOADQb9QlAiqhNAKritEYAAAAASECjzVlIutp2t+0p1QJsT7E9z/a8BscCgP6oWZ+oTQA6hNdOAPrU6GmN+0bEctvbSLrG9t0RcUNlQEScKelMSbIdDY4HAGXVrE/UJgAdwmsnAH1q6MhZRCzPvz8i6VJJezUjKQBoFPUJQIqoTQBqqbs5s72Z7ZE9P0vaX9KCZiUGAPWiPgFIEbUJQJFGTmt8qaRLbfds5+cR8dumZAUAjaE+AUgRtQnoZeLEiYUxZSaYHizqbs4iYpGk1zUxFwBoCuoTgBRRmwAU4VL6AAAAAJAAmjMAAAAASADNGQAAAAAkgOYMAAAAABJAcwYAAAAACaA5AwAAAIAE0JwBAAAAQAIcEe0bzG7fYIPM3nvvXRgzderUwpj3v//9zUinKT74wQ8Wxlx00UWFMRtvvHFhzIgRIwpjVq9eXRiD6iLCnc6hEdSmgeFnP/tZYcz73ve+muuHDx/erHQKdXd3F8ZcfPHFhTFXXHFFzfULFiwondMQ1B0REzqdRCOoTwPDtGnTktjGQNSsXmTu3LmFMZMmTWrKWM3Q12snjpwBAAAAQAJozgAAAAAgATRnAAAAAJAAmjMAAAAASADNGQAAAAAkgOYMAAAAABJAcwYAAAAACaA5AwAAAIAEDOt0AsjsuuuuNdfPmjWrcBtbb711YcySJUsKY2bPnt3wWAcffHDhNo466qjCmD/84Q+FMdddd11hzLp16wpj3vCGNxTGAHi+N7/5zYUxZSacL5o8WpJGjx5dGPPkk0/WXH/PPfcUbuOnP/1pYcw+++xTGLPvvvsWxpx66qmFMR/5yEdqrn/rW99auI21a9cWxgCoX1dXV2HM9OnT25BJWubMmdO2sa6//vq2jdVKHDkDAAAAgATQnAEAAABAAmjOAAAAACABNGcAAAAAkACaMwAAAABIAM0ZAAAAACSA5gwAAAAAEsA8Z23w6le/ujDmyiuvrLm+zBxm3d3dhTEHHHBAYUyZ+XC+9KUv1VxfZp6zMk4//fTCmJ133rkpYxXlfMUVVzRlHCAVw4cPL4z59Kc/XXP9KaecUriNTTfdtDDm5ptvLowpM+/hV7/61ZrrH3744cJtlKmBM2fOLIwZO3ZsYUyZOYB22WWXmutPOOGEwm2cfPLJhTEAqps2bVqnU0hW0b6ZOHFiU8aZO3duw7kMFBw5AwAAAIAE0JwBAAAAQAJozgAAAAAgATRnAAAAAJAAmjMAAAAASADNGQAAAAAkgOYMAAAAABJAcwYAAAAACWAS6gaNGzeuMObaa68tjBk9enTN9RdddFHhNoomi5XKTa769re/vTDmC1/4QmFMkT333LMwZsSIEQ2PU9aoUaPaNhbQamX+nstM6rn77rvXXB8RhdsomhhaGjyTh1YqU+N23HHHhsc5+uijC2MefPDBwphf//rXhTFl/ocAA0mZ2tPV1dX6RBJUZgLpdu2b6dOnt2WcFHDkDAAAAAASUNic2T7b9iO2F1Qs29L2Nbbvy79zyAFA21GfAKSI2gSgXmWOnJ0r6cBey06SdG1E7Czp2vw2ALTbuaI+AUjPuaI2AahDYXMWETdIWtNr8SGSzst/Pk/Su5ubFgAUoz4BSBG1CUC96r0gyEsjYoUkRcQK29v0FWh7iqQpdY4DAP1Vqj5RmwC0Ga+dABRq+dUaI+JMSWdKku3iy3oBQBtQmwCkivoEDF31Xq1xl
e1tJSn//kjzUgKAhlCfAKSI2gSgUL3N2eWSjsh/PkLSZc1JBwAaRn0CkCJqE4BCLppA1PYvJE2UtJWkVZK6JM2SdKGk7SUtkXRoRPT+4Gu1bQ2oQ/Obb755Ycyf/vSnwpgxY8YUxtx888011x900EGF22jW5KBTphSf5j516tSa61/5ylc2JZd169Y1ZTsjR44sjHnRi5j2r14R4U6M26z6NNBqU5m/50suuaQwZvLkyYUx11xzTc31n/vc5wq3UaZODjQveclLCmPuvffewpgyk4UXsYuffrfffnthzPjx4xvOJUHdETGh3YMO5ddOA82cOXMKY8pMxlxGmedqSop6hGaZNGlSYczcuXNbn0ib9fXaqfAzZxFxWB+riv+rA0ALUZ8ApIjaBKBeHCoAAAAAgATQnAEAAABAAmjOAAAAACABNGcAAAAAkACaMwAAAABIAM0ZAAAAACSA5gwAAAAAElA4z9lQNnz48MKY0aNHF8aUmcRvxowZNdc3a4LpMs4///zCmAMPPLDm+p133rlwG7NmzSqMOfHEEwtjrr322sKYESNGFMbsueeeNdd3d3cXbgNoh/e///2FMfvtt19hzLx58wpjiiaZHmgTTG+44YaFMe985zsLYz784Q8XxmyxxRaFMe2a5LVM/SozIXY7/xcBjWrnBNMDbZLkZj3uImX2y0Dbd63GkTMAAAAASADNGQAAAAAkgOYMAAAAABJAcwYAAAAACaA5AwAAAIAE0JwBAAAAQAJozgAAAAAgATRnAAAAAJAAJqGu4TWveU3bxho5cmTN9ZtvvnnhNspMmn344Yc3JWbXXXetuf6CCy4o3Mbxxx9fGFNmwtMyMWPHji2M2WWXXWquZxJqpGKTTTYpjFm/fn1hzC9+8YvCmJQmmR49enRhzOTJk2uuP+mkkwq3UVQLJMl2YUy7JphevHhxYUzRZOKStG7duiZkA7RP0UTK7ZpoWZImTZrUtrGKlHncZSboLqNoAumU9stAwZEzAAAAAEgAzRkAAAAAJIDmDAAAAAASQHMGAAAAAAmgOQMAAACABNCcAQAAAEACaM4AAAAAIAHMc1bDFlts0baxzjnnnJrrly9fXriNMWPGFMaUmZvnwgsvLIzp6uqquf6yyy4r3EYZo0aNakpMGePGjWvKdoBW23HHHQtjnn766cKYhx9+uBnpFDr00EMLY4466qjCmP33378wphlzi82ePbswZuXKlYUxZf6HHHzwwWVSqunEE08sjGEOMww07Zyrq0jRXF6paef8btdff33bxhoqOHIGAAAAAAmgOQMAAACABNCcAQAAAEACaM4AAAAAIAE0ZwAAAACQAJozAAAAAEgAzRkAAAAAJIDmDAAAAAASwCTUNVx66aWFMQcccEBhzAknnFAY04yJU7/73e8WxsyYMaPhcdpp++23b0pMGfvss09TtgO0WplJh3fffffCmDITzrfLsmXLCmOuueaawphf//rXNdevXbu24W1I0m677VYYc+655xbG2C6MKdo38+fPL9wGkJIykyR3dXW1PpGSyuTbjNdxA9FAm6B7IODIGQAAAAAkoLA5s3227UdsL6hYNs32Mtu3518HtTZNAHgh6hOAFFGbANSrzJGzcyUdWGX5GRGxR/71m+amBQClnCvqE4D0nCtqE4A6FDZnEXGDpDVtyAUA+oX6BCBF1CYA9WrkM2efsX1Hfuh+VF9BtqfYnmd7XgNjAUB/FNYnahOADuC1E4Ca6m3OfihpJ0l7SFohaWZfgRFxZkRMiIgJdY4FAP1Rqj5RmwC0Ga+dABSqqzmLiFUR8WxErJf0E0l7NTctAKgP9QlAiqhNAMqoqzmzvW3FzfdIWtBXLAC0E/UJQIqoTQDKKJyE2vYvJE2UtJXtpZK6JE20vYekkLRY0jGtSzFts2fPbkoMqiszYeyaNcWfuR41qs9T+/+pu7u7VE5IB/Wpb4cffnhhzJFHHtnwdso8R6+88sqmxNx9992FMe3yta99r
TBm/PjxhTFlJq4tmmR60aJFhdtAew3l2lRmwuY5c+a0PhG0RTN+l7abkMngUdicRcRhVRaf1YJcAKBfqE8AUkRtAlCvRq7WCAAAAABoEpozAAAAAEgAzRkAAAAAJIDmDAAAAAASQHMGAAAAAAmgOQMAAACABNCcAQAAAEACCuc5AzqpzOTRZWLKWLhwYVO2A6Rg9erVhTEzZ85sSsxgc8YZZxTGHHzwwYUxZSaYvu222wpjykwoDqSizCTUzTJ37tya66dPn97wNqRyj6kZj7urq6vhbaSmzP7F83HkDAAAAAASQHMGAAAAAAmgOQMAAACABNCcAQAAAEACaM4AAAAAIAE0ZwAAAACQAJozAAAAAEgAzRkAAAAAJIBJqAEAQ8ahhx5aGLPffvs1Zazu7u7CmK9//euFMevWrWtGOkBblJl0uMxky5MmTWrKWM1QZpxmTGbdrEmoy+RSZv+iMzhyBgAAAAAJoDkDAAAAgATQnAEAAABAAmjOAAAAACABNGcAAAAAkACaMwAAAABIAM0ZAAAAACSA5gwAAAAAEsAk1BjwbHc6BQCJGDlyZM31p512WuE2dtxxx6bkctxxxxXG3HzzzU0ZC0hFmQmQh+r/7WZNMl1k+vTpbRkHrcGRMwAAAABIAM0ZAAAAACSA5gwAAAAAEkBzBgAAAAAJoDkDAAAAgATQnAEAAABAAmjOAAAAACABzHOGAS8imrKd8ePH11x//vnnN2UcAK1zyy231Fy/ww47FG6jTE0544wzCmNuvfXWwhgAg8O0adMKYyZOnNjwOGXmkSsTg3Rx5AwAAAAAElDYnNkea3uO7YW277J9fL58S9vX2L4v/z6q9ekCQIbaBCBV1CcA9Spz5OwZSZ+PiNdI2lvSp23vIukkSddGxM6Srs1vA0C7UJsApIr6BKAuhc1ZRKyIiPn5z+skLZS0naRDJJ2Xh50n6d0tyhEAXoDaBCBV1CcA9erXBUFs7yBpvKRbJL00IlZIWRGyvU0f95kiaUqDeQJAn6hNAFJFfQLQH6WbM9sjJF0s6bMR8Rfbpe4XEWdKOjPfRnMuqwcAOWoTgFRRnwD0V6mrNdreUFlxuSAiLskXr7K9bb5+W0mPtCZFAKiO2gQgVdQnAPUoc7VGSzpL0sKI+GbFqsslHZH/fISky5qfHgBUR20CkCrqE4B6lTmtcV9JH5V0p+3b82UnSzpN0oW2j5K0RNKhLckQaJMHH3yw0ymgf6hNQ8zMmTMLY3baaaeGx5kxY0ZTYtavX99wLhiwqE9DTFdXV1vGmTRpUlvGQecUNmcR8XtJfZ0kPbm56QBAOdQmAKmiPgGoV6nPnAEAAAAAWovmDAAAAAASQHMGAAAAAAmgOQMAAACABNCcAQAAAEACaM4AAAAAIAE0ZwAAAACQgDKTUANDwuOPP97pFIAha9y4cYUxhxxySGHMsGG1/63dddddhdu44IILCmPWrl1bGANgcJg4cWJbxpk+fXpbxkHaOHIGAAAAAAmgOQMAAACABNCcAQAAAEACaM4AAAAAIAE0ZwAAAACQAJozAAAAAEgAzRkAAAAAJIDmDAAAAAASwCTUGPBsdzoFADUMHz68MGbWrFmFMWUmqi5ywgknFMYsWLCg4XEADB5z5sxpynbmzp1bc/20adOaMg4GNo6cAQAAAEACaM4AAAAAIAE0ZwAAAACQAJozAAAAAEgAzRkAAAAAJIDmDAAAAAASQHMGAAAAAAmgOQMAAACABDAJNZK2fPnywphly5YVxowZM6YZ6QCow9SpUwtjdt5556aM9e1vf7vm+htvvLEp4wAYOoomj5akiRMnFsZMnz698WQw6HHkDAAAAAASQHMGAAAAAAmgOQMAAACABNCcAQAAAEACaM4AAAAAIAE0ZwAAAACQAJozAAAAAEgAzRkAAAAAJMAR0b7B7PYNhiHjpptuKozZa6+9CmMeeOCBmus/9KEPFW5j/vz5hTGDUUS40zk0gtpUv2222aYwpru7u
zCmzETx99xzT2FM0fP0jjvuKNwGBpXuiJjQ6SQaQX0CBqe+XjsVHjmzPdb2HNsLbd9l+/h8+TTby2zfnn8d1OykAaAv1CYAqaI+AajXsBIxz0j6fETMtz1SUrfta/J1Z0TEN1qXHgD0idoEIFXUJwB1KWzOImKFpBX5z+tsL5S0XasTA4BaqE0AUkV9AlCvfl0QxPYOksZLuiVf9Bnbd9g+2/aoPu4zxfY82/MaSxUAqqM2AUgV9QlAf5RuzmyPkHSxpM9GxF8k/VDSTpL2UPbu0Mxq94uIMyNiwkD/QC6ANFGbAKSK+gSgv0o1Z7Y3VFZcLoiISyQpIlZFxLMRsV7STyQVXw4PAJqI2gQgVdQnAPUoc7VGSzpL0sKI+GbF8m0rwt4jaUHz0wOA6qhNAFJFfQJQrzJXa9xX0kcl3Wn79nzZyZIOs72HpJC0WNIxLcgPKDRjxozCmGOPPbYwZtGiRTXXD9U5zBJGbUrAJptsUhizYsWKwpgy85xNnjy5KWMBbUB9AlCXMldr/L2kapOk/ab56QBAOdQmAKmiPgGoV7+u1ggAAAAAaA2aMwAAAABIAM0ZAAAAACSA5gwAAAAAEkBzBgAAAAAJoDkDAAAAgATQnAEAAABAAhwR7RvMbt9gANomIqrN5zNgUJuAQas7IiZ0OolGUJ+Awamv104cOQMAAACABNCcAQAAAEACaM4AAAAAIAE0ZwAAAACQAJozAAAAAEgAzRkAAAAAJIDmDAAAAAASQHMGAAAAAAkY1ubxHpP0UMXtrfJlAwX5thb5tlar8n15C7bZbr1rk8Tvt9XIt7XINzMY6xO/29Yi39Yi30yftckRnZt43va8iJjQsQT6iXxbi3xba6Dl22kDbX+Rb2uRb2sNtHw7aaDtK/JtLfJtrU7ky2mNAAAAAJAAmjMAAAAASECnm7MzOzx+f5Fva5Fvaw20fDttoO0v8m0t8m2tgZZvJw20fUW+rUW+rdX2fDv6mTMAAAAAQKbTR84AAAAAAKI5AwAAAIAkdKw5s32g7Xts32/7pE7lUZbtxbbvtH277Xmdzqc322fbfsT2goplW9q+xvZ9+fdRncyxUh/5TrO9LN/Ht9s+qJM5VrI91vYc2wtt32X7+Hx5kvu4Rr7J7uNUUJuai9rUWtSmoYX61FzUp9ahNjWQSyc+c2Z7A0n3SnqHpKWSbpN0WET8qe3JlGR7saQJEZHkxHm23yrpCUk/i4jd8mUzJK2JiNPyIj4qIr7YyTx79JHvNElPRMQ3OplbNba3lbRtRMy3PVJSt6R3SzpSCe7jGvl+QInu4xRQm5qP2tRa1Kahg/rUfNSn1qE21a9TR872knR/RCyKiH9I+qWkQzqUy6AQETdIWtNr8SGSzst/Pk/ZH1kS+sg3WRGxIiLm5z+vk7RQ0nZKdB/XyBe1UZuajNrUWtSmIYX61GTUp9ahNtWvU83ZdpIerri9VOkX55B0te1u21M6nUxJL42IFVL2Rydpmw7nU8ZnbN+RH7pP4lB3b7Z3kDRe0i0aAPu4V77SANjHHURtao/knzdVJP+8oTYNetSn9kj+uVNF0s8dalP/dKo5c5VlqV/Tf9+IeL2kf5H06fzQMprrh5J2krSHpBWSZnY0mypsj5B0saTPRsRfOp1PkSr5Jr+PO4zahGqSf95Qm4YE6hOqSfq5Q23qv041Z0slja24/TJJyzuUSykRsTz//oikS5WdXpC6Vfk5tD3n0j7S4XxqiohVEfFsRKyX9BMlto9tb6jsCXtBRFySL052H1fLN/V9nABqU3sk+7ypJvXnDbVpyKA+tUeyz51qUn7uUJvq06nm7DZJO9ve0fZGkj4k6fIO5VLI9mb5hwNlezNJ+0taUPteSbhc0hH5z0dIuqyDuRTqebLm3qOE9rFtSzpL0sKI+GbFqiT3cV/5pryPE0Ftao8knzd9Sfl5Q20aUqhP7ZHkc6cvqT53qE0N5NKJqzVKkrNLUX5L0gaSzo6IUzuSSAm2x
yl7x0eShkn6eWr52v6FpImStpK0SlKXpFmSLpS0vaQlkg6NiCQ+SNpHvhOVHTYOSYslHdNzXnKn2X6zpN9JulPS+nzxycrOR05uH9fI9zAluo9TQW1qLmpTa1GbhhbqU3NRn1qH2tRALp1qzgAAAAAAz+nYJNQAAAAAgOfQnAEAAABAAmjOAAAAACABNGcAAAAAkACaMwAAAABIAM0ZAAAAACSA5gwAAAAAEvB/WkH2BpkeWBAAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.optim import Optimizer\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets\n", "from torchvision.transforms import transforms\n", "from random import shuffle\n", "\n", "\n", "'''\n", "Step 1: (same step)\n", "'''\n", "# Use data with only 4 and 9 as labels: which is hardest to classify\n", "label_1, label_2 = 4, 9\n", "\n", "# MNIST training data\n", "train_set = datasets.MNIST(root='./mnist_data/', train=True, transform=transforms.ToTensor(), download=True)\n", "\n", "# Use data with two labels\n", "idx = (train_set.targets == label_1) + (train_set.targets == label_2)\n", "train_set.data = train_set.data[idx]\n", "train_set.targets = train_set.targets[idx]\n", "train_set.targets[train_set.targets == label_1] = -1\n", "train_set.targets[train_set.targets == label_2] = 1\n", "\n", "# MNIST testing data\n", "test_set = datasets.MNIST(root='./mnist_data/', train=False, transform=transforms.ToTensor())\n", "\n", "# Use data with two labels\n", "idx = (test_set.targets == label_1) + (test_set.targets == label_2)\n", "test_set.data = test_set.data[idx]\n", "test_set.targets = test_set.targets[idx]\n", "test_set.targets[test_set.targets == label_1] = -1\n", "test_set.targets[test_set.targets == label_2] = 1\n", " \n", "\n", "'''\n", "Step 2: (same step)\n", "'''\n", "class LR(nn.Module) :\n", " def __init__(self, input_dim=28*28) :\n", " super().__init__()\n", " self.linear = nn.Linear(input_dim, 1, bias=False)\n", "\n", " ''' forward given input x '''\n", " def forward(self, x) :\n", " return self.linear(x.float().view(-1, 28*28))\n", "\n", "'''\n", "Step 3: (LOOK HERE)\n", "'''\n", "model = LR() \n", "\n", "def logistic_loss(output, target):\n", " return -torch.nn.functional.logsigmoid(target*output)\n", "\n", "loss_function = logistic_loss \n", "optimizer = 
torch.optim.SGD(model.parameters(), lr=255*1e-4) # LR scaled up by 255\n", "\n", " \n", "'''\n", "Step 4: Train model with SGD (LOOK HERE)\n", "'''\n", "# Use DataLoader class (Press Ctrl+/ to comment in/out)\n", "\n", "# 1. SGD\n", "# from torch.utils.data import RandomSampler\n", "# train_loader = DataLoader(dataset=train_set, batch_size=1, sampler=RandomSampler(train_set, replacement=True))\n", "\n", "# 2. cyclic SGD\n", "# train_loader = DataLoader(dataset=train_set, batch_size=1)\n", "\n", "# 3. shuffled cyclic SGD\n", "train_loader = DataLoader(dataset=train_set, batch_size=1, shuffle=True)\n", "\n", "# Train the model\n", "import time\n", "start = time.time()\n", "iter_count = 0\n", "\n", "for image,label in train_loader :\n", " iter_count += 1\n", " if iter_count > 1000:\n", " break\n", "\n", " # Clear previously computed gradient\n", " optimizer.zero_grad()\n", "\n", " # then compute gradient with forward and backward passes\n", " train_loss = loss_function(model(image), label.float())\n", " train_loss.backward()\n", "\n", " # perform SGD step (parameter update)\n", " optimizer.step()\n", "end = time.time()\n", "print(f\"Time ellapsed in training is: {end-start}\")\n", "\n", "'''\n", "Step 5: Test model (LOOK HERE)\n", "'''\n", "test_loss, correct = 0, 0\n", "misclassified_ind = []\n", "correct_ind = []\n", "\n", "# Test data\n", "test_loader = DataLoader(dataset=test_set, batch_size=1, shuffle=False)\n", "# no need to shuffle test data\n", "\n", "# Evaluate accuracy using test data\n", "for ind, (image, label) in enumerate(test_loader) :\n", "\n", " # Forward pass\n", " output = model(image)\n", "\n", " # Calculate cumulative loss\n", " test_loss += loss_function(output, label.float()).item()\n", "\n", " # Make a prediction\n", " if output.item() * label.item() >= 0 : \n", " correct += 1\n", " correct_ind += [ind]\n", " else:\n", " misclassified_ind += [ind]\n", "\n", "# Print out the results\n", "print('[Test set] Average loss: {:.4f}, Accuracy: {}/{} 
({:.2f}%)\\n'.format(\n", " test_loss /len(test_loader), correct, len(test_loader),\n", " 100. * correct / len(test_loader)))\n", "\n", "'''\n", "Step 6: (same step)\n", "''' \n", "# Misclassified images\n", "shuffle(misclassified_ind)\n", "fig = plt.figure(1, figsize=(15, 6))\n", "fig.suptitle('Misclassified Figures', fontsize=16)\n", "\n", "for k in range(3) :\n", " image = test_set.data[misclassified_ind[k]].cpu().numpy().astype('uint8')\n", " ax = fig.add_subplot(1, 3, k+1)\n", " true_label = test_set.targets[misclassified_ind[k]]\n", "\n", " if true_label == -1 :\n", " ax.set_title('True Label: {}\\nPrediction: {}'.format(label_1, label_2))\n", " else :\n", " ax.set_title('True Label: {}\\nPrediction: {}'.format(label_2, label_1))\n", " plt.imshow(image, cmap='gray')\n", "plt.show()\n", "\n", "# Correctly classified images\n", "shuffle(correct_ind)\n", "fig = plt.figure(2, figsize=(15, 6))\n", "fig.suptitle('Correctly-classified Figures', fontsize=16)\n", "\n", "for k in range(3) :\n", " image = test_set.data[correct_ind[k]].cpu().numpy().astype('uint8')\n", " ax = fig.add_subplot(1, 3, k+1)\n", " true_label = test_set.targets[correct_ind[k]]\n", "\n", " if true_label == -1 :\n", " ax.set_title('True Label: {}\\nPrediction: {}'.format(label_1, label_1))\n", " else :\n", " ax.set_title('True Label: {}\\nPrediction: {}'.format(label_2, label_2))\n", " plt.imshow(image, cmap='gray')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Iterating over the `DataLoader` returns each sample with the dataset's `transform` applied. Here that transform is `transforms.ToTensor()`, which converts a uint8 image with pixel values in 0-255 into a float tensor and divides it by 255, so the model sees values in [0, 1]. Indexing `train_set.data` directly, as the earlier training loop did, bypasses the transform and returns the raw uint8 pixels printed below."
] }, { "cell_type": "code", "execution_count": 44, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 18,\n", " 18, 18, 126, 136, 175, 26, 166, 255, 247, 127, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 0, 0, 0, 0, 30, 36, 94, 154, 170, 253,\n", " 253, 253, 253, 253, 225, 172, 253, 242, 195, 64, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 0, 0, 0, 49, 238, 253, 253, 253, 253, 253,\n", " 253, 253, 253, 251, 93, 82, 82, 56, 39, 0, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 0, 0, 0, 18, 219, 253, 253, 253, 253, 253,\n", " 198, 182, 247, 241, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 0, 0, 0, 0, 80, 156, 107, 253, 253, 205,\n", " 11, 0, 43, 154, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 14, 1, 154, 253, 90,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 139, 253, 190,\n", " 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 11, 190, 253,\n", " 70, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 35, 241,\n", " 225, 160, 108, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 81,\n", " 240, 253, 253, 119, 25, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 45, 186, 253, 253, 150, 27, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 16, 93, 252, 253, 187, 0, 0, 0, 0, 0, 0, 0, 
0],\n", " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 249, 253, 249, 64, 0, 0, 0, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 46, 130, 183, 253, 253, 207, 2, 0, 0, 0, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 39, 148,\n", " 229, 253, 253, 253, 250, 182, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 24, 114, 221, 253,\n", " 253, 253, 253, 201, 78, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 0, 0, 0, 0, 23, 66, 213, 253, 253, 253,\n", " 253, 198, 81, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 0, 0, 18, 171, 219, 253, 253, 253, 253, 195,\n", " 80, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 55, 172, 226, 253, 253, 253, 253, 244, 133, 11,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 136, 253, 253, 253, 212, 135, 132, 16, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],\n", " dtype=torch.uint8)\n" ] } ], "source": [ "train_set = datasets.MNIST(root='./mnist_data/', train=True, transform=transforms.ToTensor(), download=True)\n", "print(train_set.data[0])\n", "# train_loader = DataLoader(dataset=train_set, batch_size=1, shuffle=True)\n", "# print(next(iter(train_loader)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The data loader creates an iterator that ends once all of the data has been processed, so a single `for` loop over `train_loader` makes exactly one pass (one epoch) over the training set." ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Time ellapsed in training is: 4.084276914596558\n", "[Test set] Average loss: 0.1153, Accuracy: 1907/1991 (95.78%)\n", "\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.optim import Optimizer\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets\n", "from torchvision.transforms import transforms\n", "from random import shuffle\n", "\n", "\n", "'''\n", "Step 1: (same step)\n", "'''\n", "# Use data with only 4 and 9 as labels: which is hardest to classify\n", "label_1, label_2 = 4, 9\n", "\n", "# MNIST training data\n", "train_set = datasets.MNIST(root='./mnist_data/', train=True, transform=transforms.ToTensor(), download=True)\n", "\n", "# Use data with two labels\n", "idx = (train_set.targets == label_1) + (train_set.targets == label_2)\n", "train_set.data = train_set.data[idx]\n", "train_set.targets = train_set.targets[idx]\n", "train_set.targets[train_set.targets == label_1] = -1\n", "train_set.targets[train_set.targets == label_2] = 1\n", "\n", "# MNIST testing data\n", "test_set = datasets.MNIST(root='./mnist_data/', train=False, transform=transforms.ToTensor())\n", "\n", "# Use data with two labels\n", "idx = (test_set.targets == label_1) + (test_set.targets == label_2)\n", "test_set.data = test_set.data[idx]\n", "test_set.targets = test_set.targets[idx]\n", "test_set.targets[test_set.targets == label_1] = -1\n", "test_set.targets[test_set.targets == label_2] = 1\n", " \n", "\n", "'''\n", "Step 2: (same step)\n", "'''\n", "class LR(nn.Module) :\n", " '''\n", " Initialize model\n", " input_dim : dimension of given input data\n", " '''\n", " # MNIST data is 28x28 images\n", " def __init__(self, input_dim=28*28) :\n", " super().__init__()\n", " self.linear = nn.Linear(input_dim, 1, bias=False)\n", "\n", " ''' forward given input x '''\n", " def forward(self, x) :\n", " return self.linear(x.float().view(-1, 28*28))\n", "\n", "'''\n", "Step 3: (same step)\n", "'''\n", "model = LR() # Define a Neural Network Model\n", "\n", "def logistic_loss(output, 
target):\n", " return -torch.nn.functional.logsigmoid(target*output)\n", "\n", "loss_function = logistic_loss # Specify loss function\n", "optimizer = torch.optim.SGD(model.parameters(), lr=255*1e-4) # specify SGD with learning rate\n", "\n", " \n", "'''\n", "Step 4: Train model with SGD (LOOK HERE)\n", "'''\n", "\n", "# shuffled cyclic SGD\n", "train_loader = DataLoader(dataset=train_set, batch_size=1, shuffle=True)\n", "\n", "import time\n", "start = time.time()\n", "# Train the model (single epoch)\n", "for image, label in train_loader :\n", "\n", " # Clear previously computed gradient\n", " optimizer.zero_grad()\n", "\n", " # then compute gradient with forward and backward passes\n", " train_loss = loss_function(model(image), label.float())\n", " train_loss.backward()\n", "\n", " # perform SGD step (parameter update)\n", " optimizer.step() \n", "end = time.time()\n", "print(f\"Time ellapsed in training is: {end-start}\")\n", "\n", "'''\n", "Step 5: (same step)\n", "'''\n", "test_loss, correct = 0, 0\n", "misclassified_ind = []\n", "correct_ind = []\n", "\n", "# Test data\n", "test_loader = DataLoader(dataset=test_set, batch_size=1, shuffle=False)\n", "# no need to shuffle test data\n", "\n", "# Evaluate accuracy using test data\n", "for ind, (image, label) in enumerate(test_loader) :\n", "\n", " # Forward pass\n", " output = model(image)\n", "\n", " # Calculate cumulative loss\n", " test_loss += loss_function(output, label.float()).item()\n", "\n", " # Make a prediction\n", " if output.item() * label.item() >= 0 : \n", " correct += 1\n", " correct_ind += [ind]\n", " else:\n", " misclassified_ind += [ind]\n", "\n", "# Print out the results\n", "print('[Test set] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\\n'.format(\n", " test_loss /len(test_loader), correct, len(test_loader),\n", " 100. 
* correct / len(test_loader)))\n", "\n", "'''\n", "Step 6: (same step)\n", "''' \n", "# Misclassified images\n", "shuffle(misclassified_ind)\n", "fig = plt.figure(1, figsize=(15, 6))\n", "fig.suptitle('Misclassified Figures', fontsize=16)\n", "\n", "for k in range(3) :\n", " image = test_set.data[misclassified_ind[k]].cpu().numpy().astype('uint8')\n", " ax = fig.add_subplot(1, 3, k+1)\n", " true_label = test_set.targets[misclassified_ind[k]]\n", "\n", " if true_label == -1 :\n", " ax.set_title('True Label: {}\\nPrediction: {}'.format(label_1, label_2))\n", " else :\n", " ax.set_title('True Label: {}\\nPrediction: {}'.format(label_2, label_1))\n", " plt.imshow(image, cmap='gray')\n", "plt.show()\n", "\n", "# Correctly classified images\n", "shuffle(correct_ind)\n", "fig = plt.figure(2, figsize=(15, 6))\n", "fig.suptitle('Correctly-classified Figures', fontsize=16)\n", "\n", "for k in range(3) :\n", " image = test_set.data[correct_ind[k]].cpu().numpy().astype('uint8')\n", " ax = fig.add_subplot(1, 3, k+1)\n", " true_label = test_set.targets[correct_ind[k]]\n", "\n", " if true_label == -1 :\n", " ax.set_title('True Label: {}\\nPrediction: {}'.format(label_1, label_1))\n", " else :\n", " ax.set_title('True Label: {}\\nPrediction: {}'.format(label_2, label_2))\n", " plt.imshow(image, cmap='gray')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Running multiple epochs" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Time ellapsed in training is: 12.343320608139038\n", "[Test set] Average loss: 0.0893, Accuracy: 1922/1991 (96.53%)\n", "\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAA2cAAAFoCAYAAADTgoOZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAq4UlEQVR4nO3deZhcZZn38d8PQoIsMhAkQGQZFGEEkSWgQkaTV0dkM8IMDBjZDETG5BWFmUsG1ARwgXkVHZ1RhAHBERBEDCg4w2JAUIgswxYZ2a4ACVmAoCRhEZL7/eOchqLpqud0rU+6v5/r6qu7zrnrnLtOd91dd52nzuOIEAAAAACgt9bodQIAAAAAAJozAAAAAMgCzRkAAAAAZIDmDAAAAAAyQHMGAAAAABmgOQMAAACADNCcAUAP2D7KdpRf7xhg/YSa9R+qWX6B7XltzmXrcj9HtXO7Fffd9zgn1Cxbw/a3bC+0vcr2rE7kWG5vZsX8Bvo6ZqD8AQBo1oheJwAAw9wySYdL+mK/5UeU69bvt/x0Sf/ahby65S5J75P0+5plfyfpeEknSrpV0jOSFpZxj3Q7wdJnJN3eb9kjkl7SG/MHAKApNGcA0FtXSPqE7S9FREiS7TdJ+ltJP5V0VG1wRPSqOemIiHhO0m39Fv9V+f1bEbGqZnn/uG56ICLq7b/redkeFREvdXu/AIDOYlgjAPTWf0raStL4mmUHSlpTRXP2Ov2HNdoeYft024/YftH207ZvsT2+3/2OtX2X7RdsP2v7Jtt71kvK9u62L7c9v7zPH2x/tWwca+P2tv0b23+yvbyM+1LN+nfY/pntJWV+j9v+ie0R5frXDQssH9vM8u4r+4Yy1hvWaPsDtm+wvcz2Ctv/bXvHfjFr2v5yOUzyeds32t6h3mMfjDrDMvvv71e2t+8/jLLeENUyvxsH2MdBts+1/ZSkxTXrj7V9T83v/zzbG/Xb5vG2H6j5/d9h+8B2HAMAQPtw5gwAeusxSb9WMbTx5nLZEZJ+Jml5hft/XtLnJJ0i6W5Jb5Y0TtKrL85tf13FEMHzJM2QtErSeyVtKem3dba7Zbm9C1QMr9xB0pckbSPp0HK720i6StLlKoZb/lnStmVMn19I+qOkf5D0tKSxkvZV/TcHD1QxhPAoFcMFpWL44Lr9A23vJ+lKSVdL+kTN8bjZ9k4R8US5bKakkyWdJelaFcfnqjr7r2eNvoayFBGxsk7sqeX+/p+k6yXt2sT+BvIdSb9U8beytiTZPkPF7/bbkv5JxfH9sqQdbe8ZESttT5b0DUmnqfgbe5OknVTzNwIAyAPNGQD03g8lfcP2ZyRtKOlDkvapeN/3Sbo2Imo/h/bzvh9sv11F8/bNiDihJubqRhuNiFfP2tm2pN9Iek7SD21Pi4hnVDQdIyX9Qzk8UZJ+VXO/jVU0a5MiorY5ubjBfv/H9oLy51eHC9p+Q3Om4rN3N0XEpJq42ZIeVdGwfNb2huXjPyci/rEMu9b2SklnNDoG/fx3v9sLJL21f1C5v89KOjsiPl8uvs72yyoapFb8LiKOqdnX1ioaslMj4rSa5Q9KukXSAZJmqfgbubc2RtI1LeYCAOgAhjUCQO/9RNIoFS+mJ0taJOmGive9XdK+tr9ie7ztkf3Wf0hFrT9nMAnZfrPtM233XfTiZRVDMK2i4ZKKM2svS/qx7b+zvUm/zTyjolE6oxx6t63apNzW2yRdVA7tHFGe2XpexUVE3l+GvkvFWbfL+m3ix4Pc5TRJu9d87Vsnrm9/P+m3/PJB7m8gP+t3+29U/G77H4M5KhrpvmNwu6SdbX/H9odsr9OGXAAAHUBzBgA9FhHLVJzhOFzFkMaL+l0Io5Gvqhiq+FEVQ9aesf2D8qyVJI0uv88fZFo/kHSciuFyf6OiIZlWrlu7zPthSXur+F/yn5IW2Z5j+wPl+ijve4ekr0l60Pajtv9hkLkMpK8RPE9Fg1j7tb9ee9ybld8X6/X63055MCLuqPm
6t05c3/6WtLi/gSzsd7vvGDysNx6DN+u1Y/BDFcNK36PiDOBS21eUZ94AABlhWCMA5OGHKoYariHpsKp3ioiXJZ0p6Uzbm6poTM6StI6kv1fxOS+p+CzSH6ps0/bakiZJmlk7XNL2uwbY/2xJs22PkrSXis81XW1764h4OiIelXREOTTy3ZKmS/qu7XkR8cuqj3MAz5Tf/1nF57r6+3P5va+hGSNpbs36MS3su5G+/W1SYX8vqhgW2t9ovfb4akW/230xH5b07ADxz0ivNsnfl/T9ctjlh1UMsbxURcMGAMgEzRkA5OE6FUPv/hgRc1PBA4mIRZL+w/a+kvquWHi9iguATFXxOawqRqm4WuTL/ZYf1WDfL0n6le31VFyk4y/1WmPY1yDcbfsESVPK/Fppzv4gaZ6kHSKi0WfH7pW0QtIhqvk8nMqLmnTAfeX+DpY0u2b5wQPEPiZpjO2NI+JpSbL9Nknbqf6FWmpdp+J3u2VEXFcluYh4VtKltt8j6VNV7gMA6B6aMwDIQHnlv8pnzPrYvlLSPSomc35W0i6SPqLiTIki4hHb35R0gu31VVw1cKWkPST9b0RcOkAuf7J9m6QTbS9U0WR9UsXZt9p9H6fic03XSHpC0sYqzmQ9Kel+2zupuGjHpSqG3q2posF7Ra9vlAYtIsL2NElXlp+zu6zMc4ykPSU9HhFnRcQfy8d/iu1lKq7WuLuKBrHtIuJZ29+SdHK5v76rNfbtr3a46k9UXOXyIttn6bXj97QqKH+3Z0r6N9vbSbpJxdm4LVQMJ/2PiJht+xwVV9y8VcVwy3eoGEJ7bSuPFQDQfjRnALB6+7WKszLTVAxlfFzSv0j6Sl9ARPyj7YclfVrSkSrO7Nyrxi/OD5P0PUn/LukFFc3P8Soujd/nHhVXlfyaimF8S1VcJXByRLxge1GZzwkqrmz4ooozS/tHxJ0tPericV1j+/0qphH4DxWXiF+kYlLo2qZzpooLmRyjYljlHBUXX2nqDGUFM8r9TVExLcAcFU3pbyT9qSb/h23/nYpL38+S9KCKY3Vy1R1FxMm2H1Dx+5+mYujjEyouKPNQGfYbSUeraMg2UNE8/6jMEwCQERcjTQAAQKfYPlhFg/v+iLg5FQ8AGJ5ozgAAaKPy81z7qThj9qKk3SSdpOJzcnsG/3gBAHUwrBEAgPZaruKzeNNUXNJ+iYqzZv9MYwYAaIQzZwAAAACQASahBgAAAIAM0JwBAAAAQAZozgAAAAAgAzRnAAAAAJABmjMAAAAAyADNGQAAAABkgOYMAAAAADJAcwYAAAAAGaA5AwAAAIAM0JwBAAAAQAZozgAAAAAgAzRnAAAAAJABmjMAAAAAyADNGQAAAABkgOYMAAAAADJAcwYAAAAAGaA5AwAAAIAM0JwBAAAAQAZozgAAAAAgAzRnAAAAAJABmjMAAAAAyADNGQAAAABkgOYMAAAAADJAcwYAAAAAGaA5AwAAAIAM0JwBAAAAQAZozgAAAAAgAzRnAAAAAJABmjMAAAAAyADNGQAAAABkgOYMAAAAADJAcwYAAAAAGaA5AwAAAIAM0JwBAAAAQAZozgAAAAAgAzRnAAAAAJABmjN0je0Jtud3+74AkEJ9ApAjatPwQ3O2GrO9vOZrle0Xam5P7uB+j7J9S6e23062P2A7bH+517kAwwn1KY36BHQftSmN2tRbI3qdAJoXEev1/Wx7nqRjIuL6/nG2R0TEK93MLQe215L0r5Lm9DoXYLihPjVGfQJ6g9rUGLWp9zhzNgT1nca2/XnbiyT9YKB3bMp3Rd5e/jzK9tdtP257se2zbb+piX0fbfsB28tsP2r7UwPEnGz7advzat+lalcONU6UdK2k/21hGwDaiPr0KuoTkBFq06uoTT1GczZ0bSppI0lbSZpaIf5MSe+QtLOkt0saK+lLTex3iaT9Jb1Z0tGSvml71355bVxu/0hJ59jebrA52P6u7e/WS8L2VpI+Kem0Jh4DgM6iPlG
fgBxRm6hNPUdzNnStkjQjIl6KiBcaBdq2pGMlfS4ilkbEMklflXToYHcaEVdHxCNRuEnFuy9/3S/si2VeN0m6WtIhg80hIj4dEZ9ukMq3y/0sH+xjANBx1CfqE5AjahO1qef4zNnQ9VREvFgx9i2S1pF0Z/E8lyRZ0pqD3antfSTNUPEuzhrldu+rCXk2IlbU3H5M0uZtzuEASetHxKWDvS+ArqA+UZ+AHFGbqE09R3M2dEW/2ytUPIElSbY3rVn3tKQXJO0QEQua3aHtUZJ+KukISVdGxMu2Z6koFH02tL1uTZHZUtL97cqh9EFJ48ox45K0gaSVtt8VEZNa3DaA1lGfqE9AjqhN1KaeY1jj8HGPpB1s72x7bUkz+1ZExCpJ56oY47yJJNkea3vvBtuz7bVrvySNlDRK0lOSXinfCfrwAPc91fZI23+tYoz1T5rMoZ4v6rXx1ztLuqrc9tFNbAtA51GfqE9AjqhN1KauozkbJiLiQRUf8Lxe0kOS+s+18XlJD0u6zfZzZdx2qm9PFe/W9P/6jKTLJD0r6eMqnty1FpXrnpR0kaTjIqLvikCVcyivRnR2nce6LCIW9X2Vea2IiKUNHg+AHqE+UZ+AHFGbqE294Ij+Z3ABAAAAAN3GmTMAAAAAyADNGQAAAABkgOYMAAAAADJAcwYAAAAAGaA5Q5LtC2x/ufz5r23/ocntnG37i+3NDsBwRW0CkCNqE1pBczZE2J5n+wXby20vtv0D2+u1ez8RcXNENLpMbF8+R9l+3SVnI+K4iDi93TkNsO9Rtr9p+0nbz9r+ru21Or1fAG9EbXrdvqlNQCaoTa/bN7UpIzRnQ8sBEbGepF0l7S7pC/0DbI/oelbdd5KkcZJ2VDGh4q4a4FgA6BpqU4HaBOSF2lSgNmWE5mwIiogFkn6p4kkm22F7mu2HVEyiKNv7277b9h9t/9b2Tn33t72L7btsL7N9qaS1a9ZNsD2/5vYWtq+w/ZTtZ2z/m+2/knS2pPeV70j9sYx99TR/eftY2w/bXmr7Ktub16wL28fZfqh8F+ffbbviIThA0rcjYmlEPCXp25I+OcjDCKDNqE3UJiBH1CZqU05ozoYg21tI2lfS/9Qs/pik90h6p+1dJZ0v6VOSRkv6vqSrytPaIyXNkvSfkjaS9BNJf1tnP2tK+oWkxyRtLWmspB9HxAOSjpN0a0SsFxF/McB9/4+kr0k6RNJm5TZ+3C9sfxXvZL27jNu7vO+WZXHcst4hKL9qb7/V9gZ14gF0AbWJ2gTkiNpEbcoJzdnQMqt8t+UWSTdJ+mrNuq+V74i8IOlYSd+PiDkRsTIiLpT0kqT3ll9rSfpWRLwcEZdLur3O/vaQtLmkf4qIFRHxYkTcUie2v8mSzo+IuyLiJUn/rOIdo61rYs6IiD9GxOOSZkvaWZIi4vGI+Ity+UB+Kel422+xvamkz5TL16mYG4D2ojYVqE1AXqhNBWpTRobDONrh5GMRcX2ddU/U/LyVpCNt/9+aZSNVFIyQtCAiombdY3W2uYWkxyLilSZy3VzSXX03ImK57WdUvIs0r1y8qCb+eUlVP6j7FUl/IeluFcXzXEm7SFrSRJ4AWkdtKlCbgLxQmwrUpoxw5mz4qC0aT0j6SvkuSt/XOhFxiaSFksb2G6dc7zT4E5K29MAflo0BltV6UkWxkyTZXlfFUIEFqQeSEhEvRMT0iBgbEdtIekbSnRGxstVtA2g7ahO1CcgRtYna1BM0Z8PTuZKOs/0eF9a1vZ/t9SXdKukVSZ+xPcL2QSpOww/kdyqK0hnlNta2vVe5brGK8coj69z3YklH297Z9igVQwnmRMS8Vh+c7bG2Ny8f23slfVHSjFa3C6DjqE0AckRtQtfQnA1DEXGHivHT/ybpWUkPSzqqXPdnSQeVt5+V9PeSrqiznZUqrvDzdkmPS5pfxkvSryTNlbTI9tMD3PcGFU/+n6ooVG+TdGiV/MsPti5v8MHWt0n6raQVki6UdFJ
EXFtl2wB6h9oEIEfUJnSTXz9EFgAAAADQC5w5AwAAAIAM0JwBAAAAQAZozgAAAAAgAzRnAAAAAJABmjMAAAAAyMBAk+B1jG0uDQkMQRHhdFS+qE3AkPV0RLyl10m0gvoEDE31Xju1dObM9kds/8H2w7ZPamVbANBO1CcAkh7rdQL9UZsANNJ0c2Z7TUn/LmkfSe+UdJjtd7YrMQBoFvUJQI6oTQBSWjlztoekhyPi0XJ29B9LmtSetACgJdQnADmiNgFoqJXmbKykJ2puzy+XvY7tqbbvsH1HC/sCgMFI1idqE4Ae4LUTgIZauSDIQB9ie8OHViPiHEnnSHyoFUDXJOsTtQlAD/DaCUBDrZw5my9pi5rbb5X0ZGvpAEBbUJ8A5IjaBKChVpqz2yVta/svbY+UdKikq9qTFgC0hPoEIEfUJgANNT2sMSJesT1d0n9LWlPS+RExt22ZAUCTqE8AckRtApDiiO4NZWbcNDA0MQk1gEzdGRHjep1EK6hPwNDUkUmoAQAAAADtQXMGAAAAABmgOQMAAACADNCcAQAAAEAGaM4AAAAAIAM0ZwAAAACQAZozAAAAAMgAzRkAAAAAZIDmDAAAAAAyQHMGAAAAABmgOQMAAACADNCcAQAAAEAGaM4AAAAAIAM0ZwAAAACQAZozAAAAAMjAiF4nAADA6mb06NHJmBtuuCEZs9NOOyVj5syZ03D9Mccck9zG3LlzkzEAgN7jzBkAAAAAZIDmDAAAAAAyQHMGAAAAABmgOQMAAACADNCcAQAAAEAGaM4AAAAAIAM0ZwAAAACQAZozAAAAAMiAI6J7O7O7tzMAXRMR7nUOraA2odb222+fjPn5z3+ejNlmm23akU7Sueeem4w57rjjupBJlu6MiHG9TqIV1CdgaKr32okzZwAAAACQAZozAAAAAMgAzRkAAAAAZIDmDAAAAAAyQHMGAAAAABmgOQMAAACADNCcAQAAAEAGaM4AAAAAIAMjep0AVk8zZ85suH7GjBndSUTSxIkTkzE33nhj5xMBkL39998/GXPaaaclY7o1wXQVY8aM6XUKACoYMaLxy+7JkycntzF37txkzDrrrJOMWb58eTLm7rvvTsasWrUqGYPB4cwZAAAAAGSgpTNntudJWiZppaRXImJcO5ICgFZRnwDkiNoEoJF2DGucGBFPt2E7ANBu1CcAOaI2ARgQwxoBAAAAIAOtNmch6Vrbd9qeOlCA7am277B9R4v7AoDBaFifqE0AeoTXTgDqanVY414R8aTtTSRdZ/t/I+LXtQERcY6kcyTJdrS4PwCoqmF9ojYB6BFeOwGoq6UzZxHxZPl9iaSfSdqjHUkBQKuoTwByRG0C0EjTzZntdW2v3/ezpA9Lur9diQFAs6hPAHJEbQKQ0sqwxjGSfma7bzsXR8R/tSUr9FRqgmmpu5NMp8yePTsZw0TVww71aZg65ZRTGq4/6aSTktuoMoFrTh599NFep4DqqE3D2GWXXdZw/aRJk5Lb+O1vf5uMGT16dDJmu+22S8bstttuyZgqE1VjcJpuziLiUUnvbmMuANAW1CcAOaI2AUjhUvoAAAAAkAGaMwAAAADIAM0ZAAAAAGSA5gwAAAAAMkBzBgAAAAAZoDkDAAAAgAzQnAEAAABABhwR3duZ3b2dYUATJkxIxlSZ1Hl1U2WC6SoTVWNgEeFe59AKalPzNt5442TMs88+m4xZuXJlMmbzzTdPxjzyyCMN148cOTK5jSqef/75ZEy3JrPeb7/9kjH/9V/Ddp7jOyNiXK+TaAX1qffWXXfdZMxNN92UjHn3uxtPcffcc88ltzF9+vRkzD777JOMmTx5cjLmnnvuScbsuuuuyRgMrN5rJ86cAQAAAEAGaM4AAAAAIAM0ZwAAAACQAZozAAAAAMgAzRkAAAAAZIDmDAAAAAAyQHMGAAAAABmgOQMAAACADIzodQJon+E6wXQVVY4NgDc6+uijG64/44wzktu4+eabkzHTpk1Lxlx
88cXJmHZMMl1lItiDDz44GXP11VcnY0aMSP8bfuqppxquv+uuu5LbANC8k046KRmzyy67tLyfKVOmJGNmzZqVjNliiy1azkWSRo0a1ZbtYHA4cwYAAAAAGaA5AwAAAIAM0JwBAAAAQAZozgAAAAAgAzRnAAAAAJABmjMAAAAAyADNGQAAAABkgHnOhpDhOodZFaeeemqvUwCyU2VOsMMPP7zh+o033ji5jQMPPDAZM378+GTMRhttlIxJefHFF5MxBx10UDJm6dKlyZg11mjP+5/PP/98w/VLlixpy36A4ahKfTrllFPasq/UvJFV5jBrF9ttiUH7ceYMAAAAADJAcwYAAAAAGaA5AwAAAIAM0JwBAAAAQAZozgAAAAAgAzRnAAAAAJABmjMAAAAAyADNGQAAAABkgEmoVxMTJkzodQoAhpgTTjghGfOBD3ygC5lIb3nLW7qyn9NPPz0ZM3v27GTMlClTkjHtmoR64cKFbdkOMNzstNNOyZjzzz8/GbNixYpkTJXJrKvUlm6JiLbEoP04cwYAAAAAGUg2Z7bPt73E9v01yzayfZ3th8rvG3Y2TQB4I+oTgBxRmwA0q8qZswskfaTfspMk3RAR20q6obwNAN12gahPAPJzgahNAJqQbM4i4teSlvZbPEnSheXPF0r6WHvTAoA06hOAHFGbADSr2QuCjImIhZIUEQttb1Iv0PZUSVOb3A8ADFal+kRtAtBlvHYCkNTxqzVGxDmSzpEk21z2BUAWqE0AckV9AoavZq/WuNj2ZpJUfl/SvpQAoCXUJwA5ojYBSGq2ObtK0pHlz0dKurI96QBAy6hPAHJEbQKQlBzWaPsSSRMkbWx7vqQZks6QdJntKZIel3RwJ5McDlKTTM+YMaM7iQCrEepTaw444IBep9B2TzzxRMP1F1xwQVv2M3fu3GRMlQlcbSdjrrrqqko5IR/Upu7Yc889G67/zne+k9zGBhtskIw588wzkzHXX399MgaoItmcRcRhdVZ9sM25AMCgUJ8A5IjaBKBZzQ5rBAAAAAC0Ec0ZAAAAAGSA5gwAAAAAMkBzBgAAAAAZoDkDAAAAgAzQnAEAAABABmjOAAAAACADyXnO0B2pSaZTk1R3W2ri1CqTrwLonHXWWScZs95663Uhk/ZZuXJlMmbq1KkN1y9atKgtudx2223JmCr5jhjBv2GgWYcffnjD9TvvvHNyG9dee20y5gtf+ELVlICWceYMAAAAADJAcwYAAAAAGaA5AwAAAIAM0JwBAAAAQAZozgAAAAAgAzRnAAAAAJABmjMAAAAAyADNGQAAAABkgNkvu6DKBNK5TTI91MycObPXKQBd9dGPfjQZs+OOO3Yhk/a58MILkzFVJpQFkL8pU6YkY44++uiG6x988MHkNo488shkTJUJ5YF24cwZAAAAAGSA5gwAAAAAMkBzBgAAAAAZoDkDAAAAgAzQnAEAAABABmjOAAAAACADNGcAAAAAkAHmOeuC2bNn9zqFtouIXqcwKFXmOWMuNAwlJ598cq9TGJQVK1YkY3iOAkPDWmutlYypMv/YyJEjG64/88wzk9tYvHhxMmYost2WGLQfZ84AAAAAIAM0ZwAAAACQAZozAAAAAMgAzRkAAAAAZIDmDAAAAAAyQHMGAAAAABmgOQMAAACADNCcAQAAAEAGmIS6RUNxgumhaMaMGS3HTJw4MbmNG2+8sWpKQEftuOOOyZhuTSb//PPPJ2M+/vGPJ2MWLFjQjnQA9NinP/3pZMxee+2VjJk1a1bD9ZdeemnVlIadKvW/W/8j8HqcOQMAAACADCSbM9vn215i+/6aZTNtL7B9d/m1b2fTBIA3oj4ByBG1CUCzqpw5u0DSRwZY/s2I2Ln8uqa9aQFAJReI+gQgPxeI2gSgCcnmLCJ+LWlpF3IBgEGhPgHIEbUJQLNa+czZdNv3lqfuN6wXZHuq7Tts39HCvgBgMJL1idoEoAd47QSgoWabs+9JepuknSUtlPSNeoE
RcU5EjIuIcU3uCwAGo1J9ojYB6DJeOwFIaqo5i4jFEbEyIlZJOlfSHu1NCwCaQ30CkCNqE4AqmmrObG9Wc/NASffXiwWAbqI+AcgRtQlAFclJqG1fImmCpI1tz5c0Q9IE2ztLCknzJH2qcynmbcKECb1OAV1S5XfNJNTdRX1aPcyfPz8Z84tf/KILmbTPEUcckYwZMSL5LzYrb3rTm5Ix06dPT8Z873vfS8YsX768Uk6rq+Fcm0aPHp2MOfHEE9uyr9Qk0y+88EJb9oPe22qrrZIxjz32WBcy6bzkf46IOGyAxed1IBcAGBTqE4AcUZsANKuVqzUCAAAAANqE5gwAAAAAMkBzBgAAAAAZoDkDAAAAgAzQnAEAAABABmjOAAAAACADNGcAAAAAkIHVa4bMDFWZdJiJqoeGGTNmJGOq/D0wUTWGmyoTwU6aNCkZ89BDDyVjfv/731fKqVVrrbVWV/bTLnvvvXcy5qijjkrGHHLIIcmYyZMnJ2M+97nPNVw/e/bs5DaQp2uuuSYZM3bs2GTM3LlzkzG33nprw/Xbb799chtVjBo1Khkzfvz4ZMyf/vSnhuvXXnvt5Da23XbbZMzuu++ejKli5MiRyZhNN9204fpFixa1JZf3ve99yZgqj/uWW25JxrQr52Zx5gwAAAAAMkBzBgAAAAAZoDkDAAAAgAzQnAEAAABABmjOAAAAACADNGcAAAAAkAGaMwAAAADIAM0ZAAAAAGTAEdG9ndnd21lGqkymyUTVQ8PEiROTMUNxEuqIcK9zaMVQrE1VavuqVau6kEn7rFixIhmTmuS1ynGZNWtWMuYTn/hEMmaDDTZIxlSxbNmyltZL6YliJWmNNbr3fu3ZZ5/dcP20adPatas7I2JcuzbWC6tbfapSV6o8D5cuXZqMSU1wX2Wyazv976ubr5VTcst38eLFDdc/8sgjyW3cd999yZiXX345GTN9+vRkTCpfSdp8882TMe1Q77UTZ84AAAAAIAM0ZwAAAACQAZozAAAAAMgAzRkAAAAAZIDmDAAAAAAyQHMGAAAAABmgOQMAAACADNCcAQAAAEAGmIQaTclpQsYqkzq3Y5LvU089NRkzc+bMlvezOmIS6vy0ayJYYLDmzp2bjDnhhBMarr/++uvblQ6TUHfZ5Zdfnow56KCDupBJtQmHH3744WTMiy++mIy5+OKLK+XUqioTyh977LHJmK233joZM2fOnGTMK6+80nD9+PHjk9vIadJsSdpss826kAmTUAMAAABA1mjOAAAAACADNGcAAAAAkAGaMwAAAADIAM0ZAAAAAGSA5gwAAAAAMkBzBgAAAAAZGNHrBIBW3XTTTcmYiRMndiETIB9TpkxJxpxyyinJmDFjxjRc/+ijjya3sWDBgmRMlbmxqthhhx0art9xxx2T2xg9enQyZu21166cU6uWLFnScP2PfvSjLmUiXXLJJcmYhx56KBmzbNmydqSDDB1++OHJmNNOOy0Zc9hhhyVjUn+PqeeOJC1atCgZs7pZuXJlMqbK3K377bdfMmbp0qUN12+33XbJbbzrXe9Kxuy2227JmCqq1LBe48wZAAAAAGQg2ZzZ3sL2bNsP2J5r+/hy+Ua2r7P9UPl9w86nCwAFahOAXFGfADSrypmzVySdGBF/Jem9kqbZfqekkyTdEBHbSrqhvA0A3UJtApAr6hOApiSbs4hYGBF3lT8vk/SApLGSJkm6sAy7UNLHOpQjALwBtQlArqhPAJo1qAuC2N5a0i6S5kgaExELpaII2d6kzn2mSpraYp4AUBe1CUCuqE8ABqNyc2Z7PUk/lfTZiHjOdqX7RcQ5ks4ptxHNJAkA9VCbAOSK+gRgsCpdrdH2WiqKy0URcUW5eLHtzcr1m0lKX68UANqI2gQgV9QnAM2ocrVGSzpP0gMRcVbNqqskHVn+fKSkK9ufHgAMjNoEIFfUJwDNckTjs+W2x0u6WdJ9klaVi09WMXb6MklbSnpc0sER0XAmOk7NDx2
pv5tuqjpMBJ0TEV3/JVCb0Cnbb799MuZ3v/tdMmbddddNxtx2223JmIkTJzZc/+c//zm5jWHszogY1+2dUp/QS4ceemgy5pBDDknGHHTQQe1IB3XUe+2U/MxZRNwiqd4Lrw+2khQANIvaBCBX1CcAzar0mTMAAAAAQGfRnAEAAABABmjOAAAAACADNGcAAAAAkAGaMwAAAADIAM0ZAAAAAGSA5gwAAAAAMpCc5wwAgOFkzTXXTMbY7Zl3fd68eckYJpkGMBhXXnllMuall17qQiZoBmfOAAAAACADNGcAAAAAkAGaMwAAAADIAM0ZAAAAAGSA5gwAAAAAMkBzBgAAAAAZoDkDAAAAgAzQnAEAAABABhwR3duZ3b2doWkTJkxIxsyePbvziVTUrslg0byIWK1/CdQmYMi6MyLG9TqJVlCfgKGp3msnzpwBAAAAQAZozgAAAAAgAzRnAAAAAJABmjMAAAAAyADNGQAAAABkgOYMAAAAADJAcwYAAAAAGaA5AwAAAIAMjOh1AsjPjTfe2HJMlYms25ULAAAAMBRw5gwAAAAAMkBzBgAAAAAZoDkDAAAAgAzQnAEAAABABmjOAAAAACADNGcAAAAAkAGaMwAAAADIAM0ZAAAAAGTAEdG9ndnd2xmArokI9zqHVlCbgCHrzogY1+skWkF9Aoameq+dkmfObG9he7btB2zPtX18uXym7QW27y6/9m130gBQD7UJQK6oTwCalTxzZnszSZtFxF2215d0p6SPSTpE0vKI+HrlnfHuDzAk9eLMGbUJQAU9OXNGfQKQUu+104gKd1woaWH58zLbD0ga2970AGBwqE0AckV9AtCsQV0QxPbWknaRNKdcNN32vbbPt71hnftMtX2H7TtaSxUABkZtApAr6hOAwah8QRDb60m6SdJXIuIK22MkPS0pJJ2u4vT9JxPb4NQ8MAT18oIg1CYADfT0giDUJwD1NH1BEEmyvZakn0q6KCKuKDe4OCJWRsQqSedK2qNdyQJAFdQmALmiPgFoRpWrNVrSeZIeiIizapZvVhN2oKT7258eAAyM2gQgV9QnAM2qcrXG8ZJulnSfpFXl4pMlHSZpZxWn5udJ+lT5AdhG2+LUPDAE9ehqjdQmACm9uloj9QlAQ/VeOzEJNYCWMQk1gEwxCTWALLX0mTMAAAAAQGfRnAEAAABABmjOAAAAACADNGcAAAAAkAGaMwAAAADIAM0ZAAAAAGSA5gwAAAAAMkBzBgAAAAAZoDkDAAAAgAzQnAEAAABABmjOAAAAACADNGcAAAAAkAGaMwAAAADIAM0ZAAAAAGSA5gwAAAAAMkBzBgAAAAAZGNHl/T0t6bGa2xuXy1YX5NtZ5NtZncp3qw5ss9v61yaJ32+nkW9nkW9hKNYnfredRb6dRb6FurXJEdGB/VVj+46IGNezBAaJfDuLfDtrdcu311a340W+nUW+nbW65dtLq9uxIt/OIt/O6kW+DGsEAAAAgAzQnAEAAABABnrdnJ3T4/0PFvl2Fvl21uqWb6+tbseLfDuLfDtrdcu3l1a3Y0W+nUW+ndX1fHv6mTMAAAAAQKHXZ84AAAAAAKI5AwAAAIAs9Kw5s/0R23+w/bDtk3qVR1W259m+z/bdtu/odT792T7f9hLb99cs28j2dbYfKr9v2Msca9XJd6btBeUxvtv2vr3MsZbtLWzPtv2A7bm2jy+XZ3mMG+Sb7THOBbWpvahNnUVtGl6oT+1FfeocalMLufTiM2e215T0oKS/kTRf0u2SDouI33c9mYpsz5M0LiKynDjP9vslLZf0w4jYsVz2L5KWRsQZZRHfMCI+38s8+9TJd6ak5RHx9V7mNhDbm0naLCLusr2+pDslfUzSUcrwGDfI9xBleoxzQG1qP2pTZ1Gbhg/qU/tRnzqH2tS8Xp0520PSwxHxaET8WdKPJU3qUS5DQkT8WtLSfosnSbqw/PlCFX9kWaiTb7YiYmFE3FX+vEzSA5LGKtNj3CBfNEZ
tajNqU2dRm4YV6lObUZ86h9rUvF41Z2MlPVFze77yL84h6Vrbd9qe2utkKhoTEQul4o9O0iY9zqeK6bbvLU/dZ3Gquz/bW0vaRdIcrQbHuF++0mpwjHuI2tQd2T9vBpD984baNORRn7oj++fOALJ+7lCbBqdXzZkHWJb7Nf33iohdJe0jaVp5ahnt9T1Jb5O0s6SFkr7R02wGYHs9ST+V9NmIeK7X+aQMkG/2x7jHqE0YSPbPG2rTsEB9wkCyfu5QmwavV83ZfElb1Nx+q6Qne5RLJRHxZPl9iaSfqRhekLvF5RjavrG0S3qcT0MRsTgiVkbEKknnKrNjbHstFU/YiyLiinJxtsd4oHxzP8YZoDZ1R7bPm4Hk/ryhNg0b1KfuyPa5M5CcnzvUpub0qjm7XdK2tv/S9khJh0q6qke5JNlet/xwoGyvK+nDku5vfK8sXCXpyPLnIyVd2cNckvqerKUDldExtm1J50l6ICLOqlmV5TGul2/OxzgT1KbuyPJ5U0/Ozxtq07BCfeqOLJ879eT63KE2tZBLL67WKEkuLkX5LUlrSjo/Ir7Sk0QqsL2Nind8JGmEpItzy9f2JZImSNpY0mJJMyTNknSZpC0lPS7p4IjI4oOkdfKdoOK0cUiaJ+lTfeOSe832eEk3S7pP0qpy8ckqxiNnd4wb5HuYMj3GuaA2tRe1qbOoTcML9am9qE+dQ21qIZdeNWcAAAAAgNf0bBJqAAAAAMBraM4AAAAAIAM0ZwAAAACQAZozAAAAAMgAzRkAAAAAZIDmDAAAAAAyQHMGAAAAABn4/8AkFIR2MaKDAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.optim import Optimizer\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets\n",
    "from torchvision.transforms import transforms\n",
    "import matplotlib.pyplot as plt\n",
    "from random import shuffle\n",
    "\n",
    "'''\n",
    "Step 1: (same step)\n",
    "'''\n",
    "# Use data with only 4 and 9 as labels: which is hardest to classify\n",
    "label_1, label_2 = 4, 9\n",
    "\n",
    "# MNIST training data\n",
    "train_set = datasets.MNIST(root='./mnist_data/', train=True, transform=transforms.ToTensor(), download=True)\n",
    "\n",
    "# Use data with two labels\n",
    "idx = (train_set.targets == label_1) + (train_set.targets == label_2)\n",
    "train_set.data = train_set.data[idx]\n",
    "train_set.targets = train_set.targets[idx]\n",
    "train_set.targets[train_set.targets == label_1] = -1\n",
    "train_set.targets[train_set.targets == label_2] = 1\n",
    "\n",
    "# MNIST testing data\n",
    "test_set = datasets.MNIST(root='./mnist_data/', train=False, transform=transforms.ToTensor())\n",
    "\n",
    "# Use data with two labels\n",
    "idx = (test_set.targets == label_1) + (test_set.targets == label_2)\n",
    "test_set.data = test_set.data[idx]\n",
    "test_set.targets = test_set.targets[idx]\n",
    "test_set.targets[test_set.targets == label_1] = -1\n",
    "test_set.targets[test_set.targets == label_2] = 1\n",
    "\n",
    "\n",
    "'''\n",
    "Step 2: (same step)\n",
    "'''\n",
    "class LR(nn.Module) :\n",
    "    '''\n",
    "    Initialize model\n",
    "    input_dim : dimension of given input data\n",
    "    '''\n",
    "    # MNIST data is 28x28 images\n",
    "    def __init__(self, input_dim=28*28) :\n",
    "        super().__init__()\n",
    "        self.linear = nn.Linear(input_dim, 1, bias=False)\n",
    "\n",
    "    ''' forward given input x '''\n",
    "    def forward(self, x) :\n",
    "        return self.linear(x.float().view(-1, 28*28))\n",
    "\n",
    "'''\n",
    "Step 3: (same step)\n",
    "'''\n",
    "model = LR() # Define a Neural Network Model\n",
    "\n",
    "def logistic_loss(output, target):\n",
    "    # Per-sample loss -log(sigmoid(y * f(x))) for labels y in {-1, +1}; no reduction,\n",
    "    # which is fine here because batch_size=1 gives a single-element tensor.\n",
    "    return -torch.nn.functional.logsigmoid(target*output)\n",
    "\n",
    "loss_function = logistic_loss # Specify loss function\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=255*1e-4) # specify SGD with learning rate\n",
    "\n",
    "\n",
    "\n",
    "'''\n",
    "Step 4: Train model with SGD (LOOK HERE)\n",
    "'''\n",
    "train_loader = DataLoader(dataset=train_set, batch_size=1, shuffle=True)\n",
    "\n",
    "import time\n",
    "start = time.time()\n",
    "# Train the model for 3 epochs\n",
    "for epoch in range(3) :\n",
    "    for image, label in train_loader :\n",
    "        # Clear previously computed gradient\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # then compute gradient with forward and backward passes\n",
    "        train_loss = loss_function(model(image), label.float())\n",
    "        train_loss.backward()\n",
    "\n",
    "        # perform SGD step (parameter update)\n",
    "        optimizer.step()\n",
    "end = time.time()\n",
    "print(f\"Time ellapsed in training is: {end-start}\")\n",
    "\n",
    "\n",
    "'''\n",
    "Step 5: (same step)\n",
    "'''\n",
    "test_loss, correct = 0, 0\n",
    "misclassified_ind = []\n",
    "correct_ind = []\n",
    "\n",
    "# Test data\n",
    "test_loader = DataLoader(dataset=test_set, batch_size=1, shuffle=False)\n",
    "# no need to shuffle test data\n",
    "\n",
    "# Evaluate accuracy using test data.\n",
    "# Evaluation only reads forward values, so skip building the autograd graph.\n",
    "with torch.no_grad() :\n",
    "    for ind, (image, label) in enumerate(test_loader) :\n",
    "\n",
    "        # Forward pass\n",
    "        output = model(image)\n",
    "\n",
    "        # Calculate cumulative loss\n",
    "        test_loss += loss_function(output, label.float()).item()\n",
    "\n",
    "        # Make a prediction: counted correct when the sign of the output agrees with the +/-1 label\n",
    "        if output.item() * label.item() >= 0 :\n",
    "            correct += 1\n",
    "            correct_ind += [ind]\n",
    "        else:\n",
    "            misclassified_ind += [ind]\n",
    "\n",
    "# Print out the results\n",
    "print('[Test set] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\\n'.format(\n",
    "    test_loss /len(test_loader), correct, len(test_loader),\n",
    "    100. * correct / len(test_loader)))\n",
    "\n",
    "'''\n",
    "Step 6: (same step)\n",
    "'''\n",
    "# Show up to 3 images per group; a group can hold fewer than 3 indices\n",
    "# (e.g. a near-perfect model), and indexing past its end would raise IndexError.\n",
    "# Misclassified images\n",
    "shuffle(misclassified_ind)\n",
    "fig = plt.figure(1, figsize=(15, 6))\n",
    "fig.suptitle('Misclassified Figures', fontsize=16)\n",
    "\n",
    "for k in range(min(3, len(misclassified_ind))) :\n",
    "    image = test_set.data[misclassified_ind[k]].cpu().numpy().astype('uint8')\n",
    "    ax = fig.add_subplot(1, 3, k+1)\n",
    "    true_label = test_set.targets[misclassified_ind[k]]\n",
    "\n",
    "    if true_label == -1 :\n",
    "        ax.set_title('True Label: {}\\nPrediction: {}'.format(label_1, label_2))\n",
    "    else :\n",
    "        ax.set_title('True Label: {}\\nPrediction: {}'.format(label_2, label_1))\n",
    "    plt.imshow(image, cmap='gray')\n",
    "plt.show()\n",
    "\n",
    "# Correctly classified images\n",
    "shuffle(correct_ind)\n",
    "fig = plt.figure(2, figsize=(15, 6))\n",
    "fig.suptitle('Correctly-classified Figures', fontsize=16)\n",
    "\n",
    "for k in range(min(3, len(correct_ind))) :\n",
    "    image = test_set.data[correct_ind[k]].cpu().numpy().astype('uint8')\n",
    "    ax = fig.add_subplot(1, 3, k+1)\n",
    "    true_label = test_set.targets[correct_ind[k]]\n",
    "\n",
    "    if true_label == -1 :\n",
    "        ax.set_title('True Label: {}\\nPrediction: {}'.format(label_1, label_1))\n",
    "    else :\n",
    "        ax.set_title('True Label: {}\\nPrediction: {}'.format(label_2, label_2))\n",
    "    plt.imshow(image, cmap='gray')\n",
    "plt.show()"
 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using batch update" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Time ellapsed in training is: 5.357270240783691\n", "[Test set] Average loss: 0.1748, Accuracy: 1895/1991 (95.18%)\n", "\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.optim import Optimizer\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets\n",
    "from torchvision.transforms import transforms\n",
    "import matplotlib.pyplot as plt\n",
    "from random import shuffle\n",
    "\n",
    "'''\n",
    "Step 1: (same step)\n",
    "'''\n",
    "# Use data with only 4 and 9 as labels: which is hardest to classify\n",
    "label_1, label_2 = 4, 9\n",
    "\n",
    "# MNIST training data\n",
    "train_set = datasets.MNIST(root='./mnist_data/', train=True, transform=transforms.ToTensor(), download=True)\n",
    "\n",
    "# Use data with two labels\n",
    "idx = (train_set.targets == label_1) + (train_set.targets == label_2)\n",
    "train_set.data = train_set.data[idx]\n",
    "train_set.targets = train_set.targets[idx]\n",
    "train_set.targets[train_set.targets == label_1] = -1\n",
    "train_set.targets[train_set.targets == label_2] = 1\n",
    "\n",
    "# MNIST testing data\n",
    "test_set = datasets.MNIST(root='./mnist_data/', train=False, transform=transforms.ToTensor())\n",
    "\n",
    "# Use data with two labels\n",
    "idx = (test_set.targets == label_1) + (test_set.targets == label_2)\n",
    "test_set.data = test_set.data[idx]\n",
    "test_set.targets = test_set.targets[idx]\n",
    "test_set.targets[test_set.targets == label_1] = -1\n",
    "test_set.targets[test_set.targets == label_2] = 1\n",
    "\n",
    "\n",
    "'''\n",
    "Step 2: (same step)\n",
    "'''\n",
    "class LR(nn.Module) :\n",
    "    '''\n",
    "    Initialize model\n",
    "    input_dim : dimension of given input data\n",
    "    '''\n",
    "    # MNIST data is 28x28 images\n",
    "    def __init__(self, input_dim=28*28) :\n",
    "        super().__init__()\n",
    "        self.linear = nn.Linear(input_dim, 1, bias=False)\n",
    "\n",
    "    ''' forward given input x '''\n",
    "    def forward(self, x) :\n",
    "        # B = batch size\n",
    "        # x input has dim [B,1,28,28]\n",
    "        # convert this to dim [B,784]\n",
    "        # after linear operation, output is dim [B,1]\n",
    "        return self.linear(x.float().view(-1, 28*28))\n",
    "\n",
    "'''\n",
    "Step 3: Create the model, specify loss function and optimizer. (LOOK HERE)\n",
    "'''\n",
    "model = LR() # Define a Neural Network Model\n",
    "\n",
    "def logistic_loss(output, target):\n",
    "    # output has dim [B,1]\n",
    "    # target has dim [B]\n",
    "    # dimensions as is don't match: [B,1]*[B] would broadcast to [B,B]!\n",
    "    # convert output dim to [B]\n",
    "    # convert target dim to [B]\n",
    "    # elementwise product (* is the elementwise product for torch tensors)\n",
    "    # after logsigmoid, dim is [B]\n",
    "    # after mean, the result is a 0-dim (scalar) tensor\n",
    "    return torch.mean(-torch.nn.functional.logsigmoid(target.reshape(-1)*output.reshape(-1)))\n",
    "\n",
    "loss_function = logistic_loss # Specify loss function\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=255*1e-4) # specify SGD with learning rate\n",
    "\n",
    "\n",
    "\n",
    "'''\n",
    "Step 4: Train model with SGD (LOOK HERE)\n",
    "'''\n",
    "train_loader = DataLoader(dataset=train_set, batch_size=64, shuffle=True)\n",
    "\n",
    "import time\n",
    "start = time.time()\n",
    "# Train the model (for 3 epochs)\n",
    "for epoch in range(3) :\n",
    "    for images, labels in train_loader :\n",
    "        # Clear previously computed gradient\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # then compute gradient with forward and backward passes\n",
    "        train_loss = loss_function(model(images), labels.float())\n",
    "        train_loss.backward()\n",
    "\n",
    "        # perform SGD step (parameter update)\n",
    "        optimizer.step()\n",
    "end = time.time()\n",
    "print(f\"Time ellapsed in training is: {end - start}\")\n",
    "\n",
    "\n",
    "'''\n",
    "Step 5: (same step)\n",
    "'''\n",
    "test_loss, correct = 0, 0\n",
    "misclassified_ind = []\n",
    "correct_ind = []\n",
    "\n",
    "# Test data\n",
    "test_loader = DataLoader(dataset=test_set, batch_size=1, shuffle=False)\n",
    "# no need to shuffle test data\n",
    "\n",
    "# Evaluate accuracy using test data.\n",
    "# Evaluation only reads forward values, so skip building the autograd graph.\n",
    "with torch.no_grad() :\n",
    "    for ind, (image, label) in enumerate(test_loader) :\n",
    "\n",
    "        # Forward pass\n",
    "        output = model(image)\n",
    "\n",
    "        # Calculate cumulative loss\n",
    "        test_loss += loss_function(output, label.float()).item()\n",
    "\n",
    "        # Make a prediction: counted correct when the sign of the output agrees with the +/-1 label\n",
    "        if output.item() * label.item() >= 0 :\n",
    "            correct += 1\n",
    "            correct_ind += [ind]\n",
    "        else:\n",
    "            misclassified_ind += [ind]\n",
    "\n",
    "# Print out the results\n",
    "print('[Test set] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\\n'.format(\n",
    "    test_loss /len(test_loader), correct, len(test_loader),\n",
    "    100. * correct / len(test_loader)))\n",
    "\n",
    "'''\n",
    "Step 6: (same step)\n",
    "'''\n",
    "# Show up to 3 images per group; a group can hold fewer than 3 indices,\n",
    "# and indexing past its end would raise IndexError.\n",
    "# Misclassified images\n",
    "shuffle(misclassified_ind)\n",
    "fig = plt.figure(1, figsize=(15, 6))\n",
    "fig.suptitle('Misclassified Figures', fontsize=16)\n",
    "\n",
    "for k in range(min(3, len(misclassified_ind))) :\n",
    "    image = test_set.data[misclassified_ind[k]].cpu().numpy().astype('uint8')\n",
    "    ax = fig.add_subplot(1, 3, k+1)\n",
    "    true_label = test_set.targets[misclassified_ind[k]]\n",
    "\n",
    "    if true_label == -1 :\n",
    "        ax.set_title('True Label: {}\\nPrediction: {}'.format(label_1, label_2))\n",
    "    else :\n",
    "        ax.set_title('True Label: {}\\nPrediction: {}'.format(label_2, label_1))\n",
    "    plt.imshow(image, cmap='gray')\n",
    "plt.show()\n",
    "\n",
    "# Correctly classified images\n",
    "shuffle(correct_ind)\n",
    "fig = plt.figure(2, figsize=(15, 6))\n",
    "fig.suptitle('Correctly-classified Figures', fontsize=16)\n",
    "\n",
    "for k in range(min(3, len(correct_ind))) :\n",
    "    image = test_set.data[correct_ind[k]].cpu().numpy().astype('uint8')\n",
    "    ax = fig.add_subplot(1, 3, k+1)\n",
    "    true_label = test_set.targets[correct_ind[k]]\n",
    "\n",
    "    if true_label == -1 :\n",
    "        ax.set_title('True Label: {}\\nPrediction: {}'.format(label_1, label_1))\n",
    "    else :\n",
    "        ax.set_title('True Label: {}\\nPrediction: {}'.format(label_2, label_2))\n",
    "    plt.imshow(image, cmap='gray')\n",
    "plt.show()"
 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "CIFAR10 (two classes) with logistic regression trained with random permutation cyclic batch SGD" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "name": "stdout", 
"output_type": "stream", "text": [ "Files already downloaded and verified\n", "Time ellapsed in training is: 12.12409257888794\n", "[Test set] Average loss: 0.5001, Accuracy: 1536/2000 (76.80%)\n", "\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.optim import Optimizer\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets\n",
    "from torchvision.transforms import transforms\n",
    "import matplotlib.pyplot as plt\n",
    "from random import shuffle\n",
    "\n",
    "\n",
    "'''\n",
    "Step 1: Prepare dataset\n",
    "'''\n",
    "classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
    "label_1 = classes.index('plane')\n",
    "label_2 = classes.index('car')\n",
    "\n",
    "train_set = datasets.CIFAR10(root='./cifar_data/', train=True, transform=transforms.ToTensor(), download=True)\n",
    "\n",
    "# Use data with two labels\n",
    "train_set.targets = torch.tensor(train_set.targets)\n",
    "idx = (train_set.targets == label_1) + (train_set.targets == label_2)\n",
    "train_set.data = train_set.data[idx]\n",
    "train_set.targets = train_set.targets[idx]\n",
    "train_set.targets[train_set.targets == label_1] = -1\n",
    "train_set.targets[train_set.targets == label_2] = 1\n",
    "\n",
    "\n",
    "test_set = datasets.CIFAR10(root='./cifar_data/', train=False, transform=transforms.ToTensor(), )\n",
    "\n",
    "# Use data with two labels\n",
    "test_set.targets = torch.tensor(test_set.targets)\n",
    "idx = (test_set.targets == label_1) + (test_set.targets == label_2)\n",
    "test_set.data = test_set.data[idx]\n",
    "test_set.targets = test_set.targets[idx]\n",
    "test_set.targets[test_set.targets == label_1] = -1\n",
    "test_set.targets[test_set.targets == label_2] = 1\n",
    "\n",
    "\n",
    "'''\n",
    "Step 2: Define the neural network class.\n",
    "'''\n",
    "class LR(nn.Module) :\n",
    "    '''\n",
    "    Initialize model\n",
    "    input_dim : dimension of given input data\n",
    "    '''\n",
    "    # CIFAR-10 data is 32*32 images with 3 RGB channels\n",
    "    def __init__(self, input_dim=3*32*32) :\n",
    "        super().__init__()\n",
    "        self.linear = nn.Linear(input_dim, 1, bias=False)\n",
    "\n",
    "    ''' forward given input x '''\n",
    "    def forward(self, x) :\n",
    "        # reshape input into dim [B,3*32*32]\n",
    "        # output has dim [B,1]\n",
    "        x = self.linear(x.float().view(-1, 3*32*32)) # Flattens the given data(tensor)\n",
    "        return x\n",
    "\n",
    "\n",
    "'''\n",
    "Step 3: Create the model, specify loss function and optimizer.\n",
    "'''\n",
    "model = LR() # Define Neural Network Models\n",
    "\n",
    "def logistic_loss(output, target):\n",
    "    # Mean logistic loss over the batch; both tensors are flattened to [B] so the\n",
    "    # product is elementwise rather than broadcasting [B,1]*[B] to [B,B].\n",
    "    return torch.mean(-torch.nn.functional.logsigmoid(target.reshape(-1)*output.reshape(-1)))\n",
    "loss_function = logistic_loss # Specify loss function\n",
    "\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=1e-2) # specify SGD with learning rate\n",
    "\n",
    "\n",
    "\n",
    "'''\n",
    "Step 4: Train model with SGD\n",
    "'''\n",
    "# Use DataLoader class\n",
    "train_loader = DataLoader(dataset=train_set, batch_size=1024, shuffle=True)\n",
    "import time\n",
    "start = time.time()\n",
    "# Train the model\n",
    "for epoch in range(10) :\n",
    "    for images, labels in train_loader :\n",
    "\n",
    "        # Clear previously computed gradient\n",
    "        optimizer.zero_grad()\n",
    "        # then compute gradient with forward and backward passes\n",
    "        train_loss = loss_function(model(images), labels.float())\n",
    "        train_loss.backward()\n",
    "\n",
    "        # perform SGD step (parameter update)\n",
    "        optimizer.step()\n",
    "end = time.time()\n",
    "print(f\"Time ellapsed in training is: {end - start}\")\n",
    "\n",
    "\n",
    "'''\n",
    "Step 5: Test model (Evaluate the accuracy)\n",
    "'''\n",
    "test_loss, correct = 0, 0\n",
    "misclassified_ind = []\n",
    "correct_ind = []\n",
    "\n",
    "# Test data\n",
    "test_loader = DataLoader(dataset=test_set, batch_size=1, shuffle=False)\n",
    "\n",
    "# Evaluate accuracy using test data.\n",
    "# Evaluation only reads forward values, so skip building the autograd graph.\n",
    "with torch.no_grad() :\n",
    "    for ind, (image, label) in enumerate(test_loader) :\n",
    "\n",
    "        # Forward pass\n",
    "        output = model(image)\n",
    "\n",
    "        # Calculate cumulative loss\n",
    "        test_loss += loss_function(output, label.float()).item()\n",
    "\n",
    "        # Make a prediction: counted correct when the sign of the output agrees with the +/-1 label\n",
    "        if output.item() * label.item() >= 0 :\n",
    "            correct += 1\n",
    "            correct_ind += [ind]\n",
    "        else:\n",
    "            misclassified_ind += [ind]\n",
    "\n",
    "# Print out the results\n",
    "print('[Test set] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\\n'.format(\n",
    "    test_loss /len(test_loader), correct, len(test_loader),\n",
    "    100. * correct / len(test_loader)))\n",
    "\n",
    "\n",
    "\n",
    "'''\n",
    "Step 6: Show some incorrectly classified images and some correctly classified ones\n",
    "'''\n",
    "# Show up to 3 images per group; a group can hold fewer than 3 indices,\n",
    "# and indexing past its end would raise IndexError.\n",
    "# Misclassified images\n",
    "shuffle(misclassified_ind)\n",
    "fig = plt.figure(1, figsize=(15, 6))\n",
    "fig.suptitle('Misclassified Figures', fontsize=16)\n",
    "\n",
    "for k in range(min(3, len(misclassified_ind))) :\n",
    "    image = test_set.data[misclassified_ind[k]].astype('uint8')\n",
    "    ax = fig.add_subplot(1, 3, k+1)\n",
    "    true_label = test_set.targets[misclassified_ind[k]]\n",
    "\n",
    "    if true_label == -1 :\n",
    "        ax.set_title('True Label: {}\\nPrediction: {}'.format(classes[label_1], classes[label_2]))\n",
    "    else :\n",
    "        ax.set_title('True Label: {}\\nPrediction: {}'.format(classes[label_2], classes[label_1]))\n",
    "    plt.imshow(image)\n",
    "plt.show()\n",
    "\n",
    "# Correctly classified images\n",
    "shuffle(correct_ind)\n",
    "fig = plt.figure(2, figsize=(15, 6))\n",
    "fig.suptitle('Correctly-classified Figures', fontsize=16)\n",
    "\n",
    "for k in range(min(3, len(correct_ind))) :\n",
    "    image = test_set.data[correct_ind[k]].astype('uint8')\n",
    "    ax = fig.add_subplot(1, 3, k+1)\n",
    "    true_label = test_set.targets[correct_ind[k]]\n",
    "\n",
    "    if true_label == -1 :\n",
    "        ax.set_title('True Label: {}\\nPrediction: {}'.format(classes[label_1], classes[label_1]))\n",
    "    else :\n",
    "        ax.set_title('True Label: {}\\nPrediction: {}'.format(classes[label_2], classes[label_2]))\n",
    "    plt.imshow(image)\n",
    "plt.show()"
 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "CIFAR10 with Multilayer perceptron" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Files already downloaded and verified\n", "Time ellapsed in training is: 
213.10205554962158\n", "[Test set] Average loss: 0.3722, Accuracy: 1634/2000 (81.70%)\n", "\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.optim import Optimizer\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets\n",
    "from torchvision.transforms import transforms\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from numpy import linspace\n",
    "# shuffle is used in Step 6; import it here so this cell does not depend on earlier cells\n",
    "from random import shuffle\n",
    "\n",
    "\n",
    "'''\n",
    "Step 1: (same step)\n",
    "'''\n",
    "classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
    "label_1 = classes.index('plane')\n",
    "label_2 = classes.index('car')\n",
    "\n",
    "train_set = datasets.CIFAR10(root='./cifar_data/', train=True, transform=transforms.ToTensor(), download=True)\n",
    "\n",
    "# Use data with two labels\n",
    "train_set.targets = torch.tensor(train_set.targets)\n",
    "idx = (train_set.targets == label_1) + (train_set.targets == label_2)\n",
    "train_set.data = train_set.data[idx]\n",
    "train_set.targets = train_set.targets[idx]\n",
    "train_set.targets[train_set.targets == label_1] = -1\n",
    "train_set.targets[train_set.targets == label_2] = 1\n",
    "\n",
    "\n",
    "test_set = datasets.CIFAR10(root='./cifar_data/', train=False, transform=transforms.ToTensor(), )\n",
    "\n",
    "# Use data with two labels\n",
    "test_set.targets = torch.tensor(test_set.targets)\n",
    "idx = (test_set.targets == label_1) + (test_set.targets == label_2)\n",
    "test_set.data = test_set.data[idx]\n",
    "test_set.targets = test_set.targets[idx]\n",
    "test_set.targets[test_set.targets == label_1] = -1\n",
    "test_set.targets[test_set.targets == label_2] = 1\n",
    "\n",
    "\n",
    "'''\n",
    "Step 2: Define the neural network class (LOOK HERE)\n",
    "'''\n",
    "class MLP4(nn.Module) :\n",
    "    '''\n",
    "    Initialize model\n",
    "    input_dim : dimension of given input data\n",
    "    '''\n",
    "    # CIFAR-10 data is 32*32 images with 3 RGB channels\n",
    "    def __init__(self, input_dim=3*32*32) :\n",
    "        super().__init__()\n",
    "        self.linear = nn.Linear(input_dim, input_dim//2, bias=True)\n",
    "        self.linear2 = nn.Linear(input_dim//2, input_dim//4, bias=True)\n",
    "        self.linear3 = nn.Linear(input_dim//4, input_dim//8, bias=True)\n",
    "        self.linear4 = nn.Linear(input_dim//8, 1, bias=True)\n",
    "\n",
    "    ''' forward given input x '''\n",
    "    def forward(self, x) :\n",
    "        x = x.float().view(-1, 3*32*32)\n",
    "        x = nn.functional.relu(self.linear(x))\n",
    "        x = nn.functional.relu(self.linear2(x))\n",
    "        x = nn.functional.relu(self.linear3(x))\n",
    "        x = self.linear4(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "'''\n",
    "Step 3: Create the model, specify loss function and optimizer (LOOK HERE)\n",
    "'''\n",
    "model = MLP4() # Define Neural Network Models\n",
    "\n",
    "def logistic_loss(output, target):\n",
    "    # Mean logistic loss over the batch; both tensors are flattened to [B] so the\n",
    "    # product is elementwise rather than broadcasting [B,1]*[B] to [B,B].\n",
    "    return torch.mean(-torch.nn.functional.logsigmoid(target.reshape(-1)*output.reshape(-1)))\n",
    "loss_function = logistic_loss # Specify loss function\n",
    "\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=3*1e-2) # specify SGD with learning rate\n",
    "\n",
    "\n",
    "\n",
    "'''\n",
    "Step 4: (same step)\n",
    "'''\n",
    "# Use DataLoader class\n",
    "train_loader = DataLoader(dataset=train_set, batch_size=1024, shuffle=True)\n",
    "import time\n",
    "start = time.time()\n",
    "# Train the model\n",
    "for epoch in range(100) :\n",
    "    for images, labels in train_loader :\n",
    "\n",
    "        # Clear previously computed gradient\n",
    "        optimizer.zero_grad()\n",
    "        # then compute gradient with forward and backward passes\n",
    "        train_loss = loss_function(model(images), labels.float())\n",
    "        train_loss.backward()\n",
    "\n",
    "        # perform SGD step (parameter update)\n",
    "        optimizer.step()\n",
    "end = time.time()\n",
    "print(f\"Time ellapsed in training is: {end - start}\")\n",
    "\n",
    "\n",
    "\n",
    "'''\n",
    "Step 5: (same step)\n",
    "'''\n",
    "test_loss, correct = 0, 0\n",
    "misclassified_ind = []\n",
    "correct_ind = []\n",
    "\n",
    "# Test data\n",
    "test_loader = DataLoader(dataset=test_set, batch_size=1, shuffle=False)\n",
    "\n",
    "# Evaluate accuracy using test data.\n",
    "# Evaluation only reads forward values, so skip building the autograd graph.\n",
    "with torch.no_grad() :\n",
    "    for ind, (image, label) in enumerate(test_loader) :\n",
    "\n",
    "        # Forward pass\n",
    "        output = model(image)\n",
    "\n",
    "        # Calculate cumulative loss\n",
    "        test_loss += loss_function(output, label.float()).item()\n",
    "\n",
    "        # Make a prediction: counted correct when the sign of the output agrees with the +/-1 label\n",
    "        if output.item() * label.item() >= 0 :\n",
    "            correct += 1\n",
    "            correct_ind += [ind]\n",
    "        else:\n",
    "            misclassified_ind += [ind]\n",
    "\n",
    "# Print out the results\n",
    "print('[Test set] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\\n'.format(\n",
    "    test_loss /len(test_loader), correct, len(test_loader),\n",
    "    100. * correct / len(test_loader)))\n",
    "\n",
    "\n",
    "\n",
    "'''\n",
    "Step 6: (same step)\n",
    "'''\n",
    "# Show up to 3 images per group; a group can hold fewer than 3 indices,\n",
    "# and indexing past its end would raise IndexError.\n",
    "# Misclassified images\n",
    "shuffle(misclassified_ind)\n",
    "fig = plt.figure(1, figsize=(15, 6))\n",
    "fig.suptitle('Misclassified Figures', fontsize=16)\n",
    "\n",
    "for k in range(min(3, len(misclassified_ind))) :\n",
    "    image = test_set.data[misclassified_ind[k]].astype('uint8')\n",
    "    ax = fig.add_subplot(1, 3, k+1)\n",
    "    true_label = test_set.targets[misclassified_ind[k]]\n",
    "\n",
    "    if true_label == -1 :\n",
    "        ax.set_title('True Label: {}\\nPrediction: {}'.format(classes[label_1], classes[label_2]))\n",
    "    else :\n",
    "        ax.set_title('True Label: {}\\nPrediction: {}'.format(classes[label_2], classes[label_1]))\n",
    "    plt.imshow(image)\n",
    "plt.show()\n",
    "\n",
    "# Correctly classified images\n",
    "shuffle(correct_ind)\n",
    "fig = plt.figure(2, figsize=(15, 6))\n",
    "fig.suptitle('Correctly-classified Figures', fontsize=16)\n",
    "\n",
    "for k in range(min(3, len(correct_ind))) :\n",
    "    image = test_set.data[correct_ind[k]].astype('uint8')\n",
    "    ax = fig.add_subplot(1, 3, k+1)\n",
    "    true_label = test_set.targets[correct_ind[k]]\n",
    "\n",
    "    if true_label == -1 :\n",
    "        ax.set_title('True Label: {}\\nPrediction: {}'.format(classes[label_1], classes[label_1]))\n",
    "    else :\n",
    "        ax.set_title('True Label: {}\\nPrediction: {}'.format(classes[label_2], classes[label_2]))\n",
    "    plt.imshow(image)\n",
    "plt.show()"
 ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", 
"text": [ "MLP model parameters\n", "torch.Size([1536, 3072])\n", "torch.Size([1536])\n", "torch.Size([768, 1536])\n", "torch.Size([768])\n", "torch.Size([384, 768])\n", "torch.Size([384])\n", "torch.Size([1, 384])\n", "torch.Size([1])\n", "MLP has a total of 6196225 parameters.\n" ] } ], "source": [
    "print(\"MLP model parameters\")\n",
    "param_num = 0\n",
    "for parameter in model.parameters():\n",
    "    print(parameter.shape)\n",
    "    param_num += parameter.numel()\n",
    "print(f\"MLP has a total of {param_num} parameters.\")"
   ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "# Softmax regression for MNIST" ] },
  { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [
    { "name": "stdout", "output_type": "stream", "text": [ "Using downloaded and verified file: ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz\n", "Extracting ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw\n", "Using downloaded and verified file: ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz\n", "Extracting ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw\n", "Using downloaded and verified file: ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz\n", "Extracting ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw\n", "Using downloaded and verified file: ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n", "Extracting ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw\n", "Processing...\n", "Done!\n" ] },
    { "name": "stderr", "output_type": "stream", "text": [ "/Users/ernestryu/opt/anaconda3/lib/python3.8/site-packages/torchvision/datasets/mnist.py:480: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /Users/distiller/project/conda/conda-bld/pytorch_1603740477510/work/torch/csrc/utils/tensor_numpy.cpp:141.)\n", "  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)\n" ] },
    { "name": "stdout", "output_type": "stream", "text": [ "Time ellapsed in training is: 20.37763786315918\n", "[Test set] Average loss: 0.8003, Accuracy: 8471/10000 (84.71%)\n", "\n" ] }
   ], "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.optim import Optimizer\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets\n",
    "from torchvision.transforms import transforms\n",
    "import matplotlib.pyplot as plt\n",
    "'''\n",
    "Step 1: Load the entire MNIST dataset (LOOK HERE)\n",
    "'''\n",
    "\n",
    "train_set = datasets.MNIST(root='./mnist_data/',\n",
    "                           train=True,\n",
    "                           transform=transforms.ToTensor(),\n",
    "                           download=True)\n",
    "\n",
    "test_set = datasets.MNIST(root='./mnist_data/',\n",
    "                          train=False,\n",
    "                          transform=transforms.ToTensor())\n",
    "\n",
    "\n",
    "'''\n",
    "Step 2: Since there are 10 classes, the output should be 10 (LOOK HERE)\n",
    "'''\n",
    "class softmax(nn.Module) :\n",
    "    def __init__(self, input_dim=28*28) :\n",
    "        super().__init__()\n",
    "        self.linear = nn.Linear(input_dim, 10, bias=True)\n",
    "\n",
    "    def forward(self, x) :\n",
    "        return self.linear(x.float().view(-1, 28*28))\n",
    "\n",
    "'''\n",
    "Step 3: Create the model, specify loss function and optimizer (LOOK HERE)\n",
    "'''\n",
    "model = softmax()\n",
    "loss_function = torch.nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)\n",
    "\n",
    "\n",
    "'''\n",
    "Step 4: (same step)\n",
    "'''\n",
    "train_loader = DataLoader(dataset=train_set, batch_size=64, shuffle=True)\n",
    "\n",
    "import time\n",
    "start = time.time()\n",
    "for epoch in range(5) :\n",
    "    for images, labels in train_loader :\n",
    "        optimizer.zero_grad()\n",
    "        train_loss = loss_function(model(images), labels)\n",
    "        train_loss.backward()\n",
    "        optimizer.step()\n",
    "end = time.time()\n",
    "print(f\"Time ellapsed in training is: {end - start}\")\n",
    "\n",
    "\n",
    "'''\n",
    "Step 5: Test model (Evaluate the accuracy)\n",
    "'''\n",
    "test_loss, correct = 0, 0\n",
    "\n",
    "test_loader = DataLoader(dataset=test_set, batch_size=1, shuffle=False)\n",
    "\n",
    "# no_grad(): evaluation only needs forward passes, so skip building the autograd graph\n",
    "with torch.no_grad():\n",
    "    for ind, (image, label) in enumerate(test_loader) :\n",
    "        output = model(image)\n",
    "        # batch_size=1, so each mean loss is a single-sample loss\n",
    "        test_loss += loss_function(output, label).item()\n",
    "        pred = output.max(1, keepdim=True)[1]\n",
    "        correct += pred.eq(label.view_as(pred)).sum().item()\n",
    "\n",
    "\n",
    "print('[Test set] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\\n'.format(\n",
    "    test_loss /len(test_loader), correct, len(test_loader),\n",
    "    100. * correct / len(test_loader)))"
   ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "# GPU computing on PyTorch\n", "\n", "Check availability of GPU on system" ] },
  { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "False\n", "0\n", "1.7.0\n" ] } ], "source": [
    "import torch\n",
    "\n",
    "print(torch.cuda.is_available())\n",
    "print(torch.cuda.device_count())\n",
    "\n",
    "print(torch.__version__)"
   ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "Use `.to(device)` to create a copy of a tensor on the GPU" ] },
  { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cpu\n", "cpu\n" ] } ], "source": [
    "t = torch.tensor([\n",
    "    [1,1,1,1],\n",
    "    [2,2,2,2],\n",
    "    [3,3,3,3]\n",
    "], dtype=torch.float32)\n",
    "\n",
    "device = \"cpu\"\n",
    "# device = \"cuda:0\"\n",
    "#device = \"cuda:5\" #error if you have fewer than 6 GPUs\n",
    "\n",
    "t_dev = t.to(device)\n",
    "\n",
    "print(t.device)\n",
    "print(t_dev.device)"
   ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "Perform power iteration" ] },
  { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout",
"output_type": "stream", "text": [ "Time ellapsed in loop is: 1.3693819046020508\n", "86.06098\n" ] } ], "source": [
    "import torch\n",
    "import numpy as np\n",
    "N = 8192\n",
    "A = torch.normal(0,1/np.sqrt(N),(N,N)) # float32: 4 bytes * 8192^2 = 256 MiB of data\n",
    "x = torch.normal(0.0, 1.0,(N,1))\n",
    "\n",
    "\n",
    "device = \"cpu\"\n",
    "# device = \"cuda:0\"\n",
    "A = A.to(device)\n",
    "x = x.to(device) #error if A is sent to GPU but x is not sent to GPU\n",
    "\n",
    "import time\n",
    "# CUDA kernels run asynchronously: synchronize so the timer measures the actual work\n",
    "if A.is_cuda:\n",
    "    torch.cuda.synchronize()\n",
    "start = time.time()\n",
    "for _ in range(100):\n",
    "    x = A@x # matrix-vector product\n",
    "if A.is_cuda:\n",
    "    torch.cuda.synchronize()\n",
    "end = time.time()\n",
    "print(f\"Time ellapsed in loop is: {end - start}\")\n",
    "\n",
    "x = x.to(\"cpu\")\n",
    "print(np.linalg.norm(x))"
   ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "# Classification via MLP on GPU\n", "\n", "(GPU does not provide a speedup as the model is small.)" ] },
  { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.optim import Optimizer\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets\n",
    "from torchvision.transforms import transforms\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "'''\n",
    "Step 1\n",
    "'''\n",
    "train_dataset = datasets.MNIST(root='./mnist_data/',\n",
    "                               train=True,\n",
    "                               transform=transforms.ToTensor(),\n",
    "                               download=True)\n",
    "\n",
    "test_dataset = datasets.MNIST(root='./mnist_data/',\n",
    "                              train=False,\n",
    "                              transform=transforms.ToTensor())\n",
    "\n",
    "'''\n",
    "Step 2: Define the neural network class (LOOK HERE)\n",
    "'''\n",
    "class MLP4(nn.Module) :\n",
    "    '''\n",
    "    Initialize model\n",
    "    input_dim : dimension of given input data\n",
    "    '''\n",
    "    def __init__(self, input_dim=28*28) :\n",
    "        super().__init__()\n",
    "        self.linear = nn.Linear(input_dim, input_dim//2, bias=True)\n",
    "        self.linear2 = nn.Linear(input_dim//2, input_dim//4, bias=True)\n",
    "        self.linear3 = nn.Linear(input_dim//4, input_dim//8, bias=True)\n",
    "        self.linear4 = nn.Linear(input_dim//8, 10, bias=True)\n",
    "\n",
    "    ''' forward given input x '''\n",
    "    def forward(self, x) :\n",
    "        x = x.float().view(-1, 28*28)\n",
    "        x = nn.functional.relu(self.linear(x))\n",
    "        x = nn.functional.relu(self.linear2(x))\n",
    "        x = nn.functional.relu(self.linear3(x))\n",
    "        x = self.linear4(x)\n",
    "        return x\n",
    "'''\n",
    "Step 3 Instantiate model and send it to device (LOOK HERE)\n",
    "'''\n",
    "model = MLP4().to(device)\n",
    "loss_function = torch.nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=1e-1)\n",
    "\n",
    "\n",
    "'''\n",
    "Step 4 Load batch, send it to device, and perform SGD (LOOK HERE)\n",
    "'''\n",
    "train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=1024, shuffle=True)\n",
    "\n",
    "import time\n",
    "start = time.time()\n",
    "for epoch in range(1) :\n",
    "    for images, labels in train_loader :\n",
    "        images, labels = images.to(device), labels.to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        train_loss = loss_function(model(images), labels)\n",
    "        train_loss.backward()\n",
    "\n",
    "        optimizer.step()\n",
    "# CUDA work is asynchronous: wait for queued kernels before stopping the timer\n",
    "if device.type == \"cuda\":\n",
    "    torch.cuda.synchronize()\n",
    "end = time.time()\n",
    "print(\"Time ellapsed in training is: {}\".format(end - start))\n",
    "\n",
    "\n",
    "\n",
    "'''\n",
    "Step 5 Load batch, send it to device, and perform test (LOOK HERE)\n",
    "'''\n",
    "test_loss, correct, total = 0, 0, 0\n",
    "\n",
    "test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1024, shuffle=False)\n",
    "\n",
    "# no_grad(): evaluation only needs forward passes, so skip building the autograd graph\n",
    "with torch.no_grad():\n",
    "    for images, labels in test_loader :\n",
    "        images, labels = images.to(device), labels.to(device)\n",
    "\n",
    "        output = model(images)\n",
    "        # CrossEntropyLoss returns the batch *mean*; weight it by the batch size so that\n",
    "        # dividing by `total` below gives the true per-sample average loss\n",
    "        test_loss += loss_function(output, labels).item() * labels.size(0)\n",
    "\n",
    "        pred = output.max(1, keepdim=True)[1]\n",
    "        correct += pred.eq(labels.view_as(pred)).sum().item()\n",
    "\n",
    "        total += labels.size(0)\n",
    "\n",
    "print('[Test set] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\\n'.format(\n",
    "    test_loss /total, correct, total,\n",
    "    100. * correct / total))\n"
   ] }
 ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.3" } }, "nbformat": 4, "nbformat_minor": 4 }